dofbi committed
Commit 3ef36df · 1 Parent(s): cbfbd37
TTS/tts/datasets/formatters.py CHANGED
@@ -80,6 +80,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
             {
                 "text": row.text,
                 "audio_file": audio_path,
+                "ref_file": "null" if "ref_file" not in metadata.columns else os.path.join(root_path, row.ref_file),
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
                 "root_path": root_path,
TTS/tts/layers/xtts/gpt.py CHANGED
@@ -184,6 +184,63 @@ class GPT(nn.Module):
         # XTTS v1
         self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
         self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
+
+    def resize_text_embeddings(self, new_num_tokens: int):
+        old_embeddings_requires_grad = self.text_embedding.weight.requires_grad
+
+        old_num_tokens, old_embedding_dim = self.text_embedding.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return
+
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
+            device=self.text_embedding.weight.device,
+            dtype=self.text_embedding.weight.dtype,
+        )
+
+        # number of tokens to copy
+        n = min(old_num_tokens, new_num_tokens)
+
+        new_embeddings.weight.data[:n, :] = self.text_embedding.weight.data[:n, :]
+
+        self.text_embedding.weight.data = new_embeddings.weight.data
+        self.text_embedding.num_embeddings = new_embeddings.weight.data.shape[0]
+        if self.text_embedding.padding_idx is not None and (new_num_tokens - 1) < self.text_embedding.padding_idx:
+            self.text_embedding.padding_idx = None
+
+        self.text_embedding.requires_grad_(old_embeddings_requires_grad)
+
+    def resize_text_head(self, new_num_tokens: int):
+        old_lm_head_requires_grad = self.text_head.weight.requires_grad
+
+        old_num_tokens, old_lm_head_dim = self.text_head.weight.size()
+
+        new_lm_head_shape = (old_lm_head_dim, new_num_tokens)
+        has_new_lm_head_bias = self.text_head.bias is not None
+
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=self.text_head.weight.device,
+            dtype=self.text_head.weight.dtype,
+        )
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        # Copy weights to the new lm head
+        new_lm_head.weight.data[:num_tokens_to_copy, :] = self.text_head.weight.data[:num_tokens_to_copy, :]
+
+        # Copy bias weights to the new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = self.text_head.bias.data[:num_tokens_to_copy]
+
+        self.text_head = new_lm_head
+
+        self.text_head.requires_grad_(old_lm_head_requires_grad)
+

     def get_grad_norm_parameter_groups(self):
         return {
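
Note: the two helpers above follow the familiar grow-the-vocabulary pattern (copy the old rows, preserve requires_grad, drop padding_idx if it falls out of range). A minimal usage sketch, assuming a GPT instance named gpt and an illustrative number of added tokens:

    # hypothetical call site after the tokenizer gained 16 new tokens
    new_vocab_size = gpt.text_embedding.num_embeddings + 16
    gpt.resize_text_embeddings(new_vocab_size)
    gpt.resize_text_head(new_vocab_size)
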
TTS/tts/layers/xtts/tokenizer.py CHANGED
@@ -621,7 +621,7 @@ class VoiceBpeTokenizer:
 
     def check_input_length(self, txt, lang):
         lang = lang.split("-")[0]  # remove the region
-        limit = self.char_limits.get(lang, 250)
+        limit = self.char_limits.get(lang, 300)
         if len(txt) > limit:
             print(
                 f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
@@ -640,7 +640,8 @@ class VoiceBpeTokenizer:
             # @manmay will implement this
             txt = basic_cleaners(txt)
         else:
-            raise NotImplementedError(f"Language '{lang}' is not supported.")
+            txt = basic_cleaners(txt)
+            # print(f"[!] Warning: Preprocess [Language '{lang}'] text is not implemented, use `basic_cleaners` instead.")
         return txt
 
     def encode(self, txt, lang):
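
Note: with the .get fallback and the basic_cleaners default, an unlisted language code such as Wolof ("wo") no longer aborts preprocessing. A hedged sketch of the resulting behaviour (the vocab path is illustrative; the method names are taken from this class):

    tok = VoiceBpeTokenizer(vocab_file="XTTS_v2.0_original_model_files/vocab.json")
    tok.check_input_length("Màngi tuddu Aadama.", "wo")    # "wo" is not in char_limits -> 300-character default limit
    txt = tok.preprocess_text("Màngi tuddu Aadama", "wo")  # cleaned with basic_cleaners instead of raising NotImplementedError
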
TTS/tts/layers/xtts/trainer/dataset.py CHANGED
@@ -23,29 +23,41 @@ def key_samples_by_col(samples, col):
     return samples_by_col
 
 
-def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
-    rel_clip = load_audio(gt_path, sample_rate)
-    # if eval uses a middle size sample when it is possible to be more reproducible
-    if is_eval:
-        sample_length = int((min_sample_length + max_sample_length) / 2)
-    else:
-        sample_length = random.randint(min_sample_length, max_sample_length)
-    gap = rel_clip.shape[-1] - sample_length
-    if gap < 0:
-        sample_length = rel_clip.shape[-1] // 2
-        gap = rel_clip.shape[-1] - sample_length
-
-    # if eval start always from the position 0 to be more reproducible
-    if is_eval:
-        rand_start = 0
-    else:
-        rand_start = random.randint(0, gap)
-
-    rand_end = rand_start + sample_length
-    rel_clip = rel_clip[:, rand_start:rand_end]
-    rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
-    cond_idxs = [rand_start, rand_end]
-    return rel_clip, rel_clip.shape[-1], cond_idxs
+def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False, ref_path="null"):
+    if ref_path == "null":
+        rel_clip = load_audio(gt_path, sample_rate)
+        # if eval uses a middle size sample when it is possible to be more reproducible
+        if is_eval:
+            sample_length = int((min_sample_length + max_sample_length) / 2)
+        else:
+            sample_length = random.randint(min_sample_length, max_sample_length)
+        gap = rel_clip.shape[-1] - sample_length
+        if gap < 0:
+            sample_length = rel_clip.shape[-1] // 2
+            gap = rel_clip.shape[-1] - sample_length
+
+        # if eval start always from the position 0 to be more reproducible
+        if is_eval:
+            rand_start = 0
+        else:
+            rand_start = random.randint(0, gap)
+
+        rand_end = rand_start + sample_length
+        rel_clip = rel_clip[:, rand_start:rand_end]
+        rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
+        cond_idxs = [rand_start, rand_end]
+        return rel_clip, rel_clip.shape[-1], cond_idxs
+    else:
+        rel_clip = load_audio(ref_path, sample_rate)
+
+        sample_length = min(max_sample_length, rel_clip.shape[-1])
+
+        rel_clip = rel_clip[:, :sample_length]
+        rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
+        cond_idxs = [0, sample_length]
+        return rel_clip, rel_clip.shape[-1], cond_idxs
+
 
 
 class XTTSDataset(torch.utils.data.Dataset):
@@ -110,14 +122,14 @@ class XTTSDataset(torch.utils.data.Dataset):
         wav = load_audio(audiopath, self.sample_rate)
         if text is None or len(text.strip()) == 0:
             raise ValueError
-        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
+        if wav is None or wav.shape[-1] < (0.2 * self.sample_rate):
             # Ultra short clips are also useless (and can cause problems within some models).
             raise ValueError
 
         if self.use_masking_gt_prompt_approach:
             # get a slice from GT to condition the model
             cond, _, cond_idxs = get_prompt_slice(
-                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval, sample["ref_file"]
             )
             # if use masking do not use cond_len
             cond_len = torch.nan
@@ -128,7 +140,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 else audiopath
             )
             cond, cond_len, _ = get_prompt_slice(
-                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval, sample["ref_file"]
             )
             # if do not use masking use cond_len
             cond_idxs = torch.nan
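
Note: the new ref_path branch pairs with the ref_file field emitted by the patched coqui formatter above; when a sample carries a real ref_file path, the conditioning clip is read from that file (from position 0, truncated or padded to max_sample_length) instead of being randomly cropped from the ground-truth utterance. A minimal sketch with illustrative paths, lengths, and sample rate (not from the commit):

    # random crop from the ground-truth wav, as before (ref_file column absent -> "null")
    cond, cond_len, idxs = get_prompt_slice("wavs/utt_0001.wav", 132300, 66150, 22050, is_eval=False)

    # fixed reference clip supplied via the metadata's ref_file column
    cond, cond_len, idxs = get_prompt_slice(
        "wavs/utt_0001.wav", 132300, 66150, 22050, is_eval=False, ref_path="refs/anta_ref.wav"
    )
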
TTS/tts/layers/xtts/trainer/dvae_dataset.py ADDED
@@ -0,0 +1,132 @@
+import torch
+import random
+from TTS.tts.models.xtts import load_audio
+
+torch.set_num_threads(1)
+
+
+def key_samples_by_col(samples, col):
+    """Returns a dictionary of samples keyed by language."""
+    samples_by_col = {}
+    for sample in samples:
+        col_val = sample[col]
+        assert isinstance(col_val, str)
+        if col_val not in samples_by_col:
+            samples_by_col[col_val] = []
+        samples_by_col[col_val].append(sample)
+    return samples_by_col
+
+
+class DVAEDataset(torch.utils.data.Dataset):
+    def __init__(self, samples, sample_rate, is_eval, max_wav_len=255995):
+        self.sample_rate = sample_rate
+        self.is_eval = is_eval
+        self.max_wav_len = max_wav_len
+        self.samples = samples
+        self.training_seed = 1
+        self.failed_samples = set()
+        if not is_eval:
+            random.seed(self.training_seed)
+            random.shuffle(self.samples)
+            # order by language
+            self.samples = key_samples_by_col(self.samples, "language")
+            print(" > Sampling by language:", self.samples.keys())
+        else:
+            # for evaluation, load and drop corrupted samples up front to ensure reproducibility
+            self.check_eval_samples()
+
+    def check_eval_samples(self):
+        print(" > Filtering invalid eval samples!!")
+        new_samples = []
+        for sample in self.samples:
+            try:
+                _, wav = self.load_item(sample)
+            except Exception:
+                continue
+            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
+            if wav is None or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len):
+                continue
+            new_samples.append(sample)
+        self.samples = new_samples
+        print(" > Total eval samples after filtering:", len(self.samples))
+
+    def load_item(self, sample):
+        audiopath = sample["audio_file"]
+        wav = load_audio(audiopath, self.sample_rate)
+        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
+            # Ultra short clips are also useless (and can cause problems within some models).
+            raise ValueError
+        return audiopath, wav
+
+    def __getitem__(self, index):
+        if self.is_eval:
+            sample = self.samples[index]
+            sample_id = str(index)
+        else:
+            # select a random language
+            lang = random.choice(list(self.samples.keys()))
+            # select a random sample
+            index = random.randint(0, len(self.samples[lang]) - 1)
+            sample = self.samples[lang][index]
+            # a unique id for each sample, used to track failures
+            sample_id = lang + "_" + str(index)
+
+        # ignore samples that we already know are not valid
+        if sample_id in self.failed_samples:
+            # call __getitem__ again to get another sample
+            return self[1]
+
+        # try to load the sample; if it fails, add it to the failed samples list
+        try:
+            audiopath, wav = self.load_item(sample)
+        except Exception:
+            self.failed_samples.add(sample_id)
+            return self[1]
+
+        # check the audio size limits; if the sample is out of the limits, add it to failed_samples
+        if wav is None or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len):
+            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
+            # It's hard to handle this situation properly. Best bet is to return another valid sample and skew the dataset somewhat as a result.
+            self.failed_samples.add(sample_id)
+            return self[1]
+
+        res = {
+            "wav": wav,
+            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
+            "filenames": audiopath,
+        }
+        return res
+
+    def __len__(self):
+        if self.is_eval:
+            return len(self.samples)
+        return sum([len(v) for v in self.samples.values()])
+
+    def collate_fn(self, batch):
+        # convert list of dicts to dict of lists
+        B = len(batch)
+
+        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+
+        # stack features that already have the same shape
+        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
+
+        max_wav_len = batch["wav_lengths"].max()
+
+        # create the zero-padded wav tensor
+        wav_padded = torch.FloatTensor(B, 1, max_wav_len)
+        wav_padded = wav_padded.zero_()
+        for i in range(B):
+            wav = batch["wav"][i]
+            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
+
+        batch["wav"] = wav_padded
+        return batch
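
Note: a hedged usage sketch for the new DVAE dataset (the sample dicts and the 22050 Hz rate are illustrative; collate_fn is passed explicitly because the clips are variable-length):

    from torch.utils.data import DataLoader

    samples = [
        {"audio_file": "wavs/utt_0001.wav", "language": "wo"},
        {"audio_file": "wavs/utt_0002.wav", "language": "wo"},
    ]
    train_ds = DVAEDataset(samples, sample_rate=22050, is_eval=False)
    loader = DataLoader(train_ds, batch_size=2, collate_fn=train_ds.collate_fn)
    for batch in loader:
        wavs, lens = batch["wav"], batch["wav_lengths"]  # zero-padded [B, 1, T_max] and original lengths
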
TTS/tts/layers/xtts/trainer/gpt_trainer.py CHANGED
@@ -97,7 +97,8 @@ class GPTTrainer(BaseTTS):
         states_keys = list(gpt_checkpoint.keys())
         for key in states_keys:
             if "gpt." in key:
-                new_key = key.replace("gpt.", "")
+                # new_key = key.replace("gpt.", "")
+                new_key = key[4:]
                 gpt_checkpoint[new_key] = gpt_checkpoint[key]
                 del gpt_checkpoint[key]
             else:
@@ -484,6 +485,40 @@
 
         state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
 
+        # edit the checkpoint if the number of tokens has changed, to get the best possible transfer learning
+        if (
+            "gpt.text_embedding.weight" in state
+            and state["gpt.text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
+        ):
+            num_new_tokens = (
+                self.xtts.gpt.text_embedding.weight.shape[0] - state["gpt.text_embedding.weight"].shape[0]
+            )
+            print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
+
+            # add new rows to the embedding layer (text_embedding)
+            emb_g = state["gpt.text_embedding.weight"]
+            new_row = torch.randn(num_new_tokens, emb_g.shape[1])
+            start_token_row = emb_g[-1, :]
+            emb_g = torch.cat([emb_g, new_row], axis=0)
+            emb_g[-1, :] = start_token_row
+            state["gpt.text_embedding.weight"] = emb_g
+
+            # add new weights to the linear layer (text_head)
+            text_head_weight = state["gpt.text_head.weight"]
+            start_token_row = text_head_weight[-1, :]
+            new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
+            text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
+            text_head_weight[-1, :] = start_token_row
+            state["gpt.text_head.weight"] = text_head_weight
+
+            # add new biases to the linear layer (text_head)
+            text_head_bias = state["gpt.text_head.bias"]
+            start_token_row = text_head_bias[-1]
+            new_bias_entry = torch.zeros(num_new_tokens)
+            text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
+            text_head_bias[-1] = start_token_row
+            state["gpt.text_head.bias"] = text_head_bias
+
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
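
Note: for intuition about the state-dict surgery above, a toy sketch (sizes are illustrative; treating the checkpoint's final embedding row as the special start/stop text token is an assumption):

    import torch

    old_emb = torch.arange(12.0).reshape(4, 3)      # pretend checkpoint embedding: 4 tokens, dim 3
    num_new_tokens = 2
    new_rows = torch.randn(num_new_tokens, 3)
    last_row = old_emb[-1, :]                       # assumed special-token row
    grown = torch.cat([old_emb, new_rows], axis=0)  # shape (6, 3); rows 0-3 keep their values
    grown[-1, :] = last_row                         # the special token also occupies the new final slot
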
TTS/tts/models/xtts.py CHANGED
@@ -523,7 +523,7 @@ class Xtts(BaseTTS):
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
         if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            text = split_sentence(text, language, self.tokenizer.char_limits.get(language, 250))
         else:
             text = [text]
 
@@ -553,6 +553,7 @@
             output_attentions=False,
             **hf_generate_kwargs,
         )
+
         expected_output_len = torch.tensor(
             [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
         )
@@ -633,7 +634,7 @@
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
         if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            text = split_sentence(text, language, self.tokenizer.char_limits.get(language, 250))
         else:
             text = [text]
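
Note: the .get fallback keeps inference and inference_stream from raising KeyError for languages absent from char_limits, such as the Wolof code "wo" this Space targets (this path defaults to 250 characters, while check_input_length above now defaults to 300). A hedged illustration, assuming a loaded Xtts model:

    # "wo" is not in char_limits, so sentence splitting uses the 250-character fallback
    limit = model.tokenizer.char_limits.get("wo", 250)
    chunks = split_sentence("Màngi tuddu Aadama. " * 30, "wo", limit)
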
app.py CHANGED
@@ -6,45 +6,72 @@ import sys
 import soundfile as sf
 import numpy as np
 import logging
+import tempfile
 
 # Logger configuration
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Local download path for the files (make sure the folder exists)
-LOCAL_DOWNLOAD_PATH = "./local_model"
-
+LOCAL_DOWNLOAD_PATH = os.path.dirname("/content")  # use the script's path
 # Download the inference script
 repo_id = "dofbi/galsenai-xtts-v2-wolof-inference"
-inference_file = hf_hub_download(repo_id=repo_id, filename="inference.py",local_dir=LOCAL_DOWNLOAD_PATH)
+inference_file = hf_hub_download(repo_id=repo_id, filename="inference.py", local_dir=LOCAL_DOWNLOAD_PATH)
 
 # Add the folder to the search path
 sys.path.insert(0, LOCAL_DOWNLOAD_PATH)
 
-# Import the function from the downloaded inference script
-from inference import generate_audio
+# Import the class from the downloaded inference script
+from inference import WolofXTTSInference
+
+# Initialise the model only once
+tts_model = WolofXTTSInference()
 
-def tts(text, audio_reference):
+def tts(text: str, audio_reference: tuple[int, np.ndarray]) -> tuple[int, np.ndarray] | str:
+    """
+    Synthesises speech from a text, using a reference audio clip.
+
+    Args:
+        text (str): The text to synthesise.
+        audio_reference (tuple[int, np.ndarray]): A tuple holding the sample rate and the reference audio data.
+
+    Returns:
+        tuple[int, np.ndarray] | str: a tuple holding the sample rate and the synthesised audio data, or an error message.
+    """
     logging.debug(f"tts function called with text: {text} and audio_reference: {audio_reference}")
-    if text and audio_reference is not None:
-        # Temporarily save the reference audio
-        temp_audio_path = "temp_audio_ref.wav"
-        sf.write(temp_audio_path, audio_reference, 44100)
-        logging.debug(f"Audio reference saved to {temp_audio_path}")
-        audio_output, sample_rate = generate_audio(text, temp_audio_path, LOCAL_DOWNLOAD_PATH)
-        logging.debug(f"Audio generated with sample rate: {sample_rate}")
-        return (sample_rate, audio_output)
-    else:
+
+    if not text or audio_reference is None:
         logging.debug("Text or audio reference is missing")
         return "Veuillez entrer un texte et fournir un audio de référence."
 
-demo = gr.Interface(
-    fn=tts,
-    inputs=[
-        gr.Textbox(label="Text to synthesize"),
-        gr.Audio(type="numpy", label="Reference audio")
-    ],
-    outputs=gr.Audio(label="Synthesized audio"),
-)
+    try:
+        sample_rate, audio_data = audio_reference
+
+        # Create a temporary file for the reference audio
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
+            sf.write(temp_audio_file.name, audio_data, sample_rate)
+            logging.debug(f"Audio reference saved to {temp_audio_file.name}")
+
+            # Use the generate_audio method of the new class (called inside the with-block so the temp file still exists)
+            audio_output, output_sample_rate = tts_model.generate_audio(
+                text,
+                reference_audio=temp_audio_file.name
+            )
+
+        logging.debug(f"Audio generated with sample rate: {output_sample_rate}")
+        return (output_sample_rate, audio_output)
+
+    except Exception as e:
+        logging.error(f"Error during audio generation: {e}")
+        return f"Une erreur s'est produite lors de la génération audio: {e}"
 
 if __name__ == "__main__":
+    demo = gr.Interface(
+        fn=tts,
+        inputs=[
+            gr.Textbox(label="Text to synthesize"),
+            gr.Audio(type="numpy", label="Reference audio")
+        ],
+        outputs=gr.Audio(label="Synthesized audio"),
+    )
+
     demo.launch()
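
Note on the temp-file handling above (a sketch under the same tempfile usage, not a change to the commit): with delete=True the reference wav only exists inside the with block, so the model call has to read it before the block exits.

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as f:
        sf.write(f.name, audio_data, sample_rate)
        audio, sr = tts_model.generate_audio(text, reference_audio=f.name)  # read while the file still exists
    # f.name has been deleted at this point
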
local_model/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.07 kB)
 
local_model/inference.py ADDED
@@ -0,0 +1,198 @@
+import torch
+import os
+import logging
+import soundfile as sf
+import numpy as np
+from huggingface_hub import hf_hub_download
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+
+# --- CONSTANTS ---
+REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
+LOCAL_DIR = "./models"
+
+
+class WolofXTTSInference:
+    def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR):
+        # Logging configuration
+        logging.basicConfig(
+            level=logging.INFO,
+            format='%(asctime)s - %(levelname)s - %(message)s'
+        )
+        self.logger = logging.getLogger(__name__)
+
+        # Create the local folder if it does not exist
+        os.makedirs(local_dir, exist_ok=True)
+
+        # Download the required files
+        try:
+            # Create the required sub-folders
+            os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
+            os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)
+
+            # Download the checkpoint
+            self.model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="Anta_GPT_XTTS_Wo/best_model_89250.pth",
+                local_dir=local_dir
+            )
+
+            # Download the configuration file
+            self.config_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="Anta_GPT_XTTS_Wo/config.json",
+                local_dir=local_dir
+            )
+
+            # Download the vocabulary
+            self.vocab_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="XTTS_v2.0_original_model_files/vocab.json",
+                local_dir=local_dir
+            )
+
+            # Download the reference audio
+            self.reference_audio = hf_hub_download(
+                repo_id=repo_id,
+                filename="anta_sample.wav",
+                local_dir=local_dir
+            )
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}")
+            raise
+
+        # Device selection
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+        # Model initialisation
+        self.model = self._load_model()
+
+    def _load_model(self):
+        """Loads the XTTS model."""
+        try:
+            self.logger.info("Chargement du modèle XTTS...")
+
+            # Initialise the model
+            config = XttsConfig()
+            config.load_json(self.config_path)
+            model = Xtts.init_from_config(config)
+
+            # Load the checkpoint with load_checkpoint
+            model.load_checkpoint(
+                config,
+                checkpoint_path=self.model_path,
+                vocab_path=self.vocab_path,
+                use_deepspeed=False
+            )
+
+            model.to(self.device)
+            model.eval()  # put the model in evaluation mode
+
+            self.logger.info("Modèle chargé avec succès!")
+            return model
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors du chargement du modèle : {e}")
+            raise
+
+    def generate_audio(
+        self,
+        text: str,
+        reference_audio: str = None,
+        speed: float = 1.06,
+        language: str = "wo",
+        output_path: str = None
+    ) -> tuple[np.ndarray, int]:
+        """
+        Generates audio from the given text.
+
+        Args:
+            text (str): Text to convert to audio.
+            reference_audio (str, optional): Path to the reference audio. Defaults to None.
+            speed (float, optional): Playback speed. Defaults to 1.06.
+            language (str, optional): Language of the text. Defaults to "wo".
+            output_path (str, optional): Where to save the generated audio. Defaults to None.
+
+        Returns:
+            tuple[np.ndarray, int]: audio_array, sample_rate
+        """
+        if not text:
+            raise ValueError("Le texte ne peut pas être vide.")
+
+        try:
+            # Use the provided reference audio, or the default one
+            ref_audio = reference_audio or self.reference_audio
+
+            # Get the conditioning embeddings
+            gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
+                audio_path=[ref_audio],
+                gpt_cond_len=self.model.config.gpt_cond_len,
+                max_ref_length=self.model.config.max_ref_len,
+                sound_norm_refs=self.model.config.sound_norm_refs
+            )
+
+            # Generate the audio
+            result = self.model.inference(
+                text=text.lower(),
+                gpt_cond_latent=gpt_cond_latent,
+                speaker_embedding=speaker_embedding,
+                do_sample=False,
+                speed=speed,
+                language=language,
+                enable_text_splitting=True
+            )
+
+            # Output sample rate
+            sample_rate = self.model.config.audio.sample_rate
+
+            # Optional save to disk
+            if output_path:
+                sf.write(output_path, result["wav"], sample_rate)
+                self.logger.info(f"Audio sauvegardé dans {output_path}")
+
+            return result["wav"], sample_rate
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors de la génération de l'audio : {e}")
+            raise
+
+    def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]:
+        """
+        Generates audio from the text and a configuration dictionary.
+
+        Args:
+            text (str): Text to convert to audio.
+            config (dict): Configuration dictionary (speed, language, reference_audio).
+            output_path (str, optional): Where to save the generated audio. Defaults to None.
+
+        Returns:
+            tuple[np.ndarray, int]: audio_array, sample_rate
+        """
+        speed = config.get('speed', 1.06)
+        language = config.get('language', "wo")
+        reference_audio = config.get('reference_audio', None)
+        return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path)
+
+
+# Usage example
+if __name__ == "__main__":
+    tts = WolofXTTSInference()
+
+    # Audio generation example
+    text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"
+
+    # Simple
+    audio, sr = tts.generate_audio(
+        text,
+        output_path="generated_audio.wav"
+    )
+
+    # With a config dict
+    config_gen_audio = {
+        "speed": 1.2,
+        "language": "wo",
+    }
+    audio, sr = tts.generate_audio_from_config(
+        text=text,
+        config=config_gen_audio,
+        output_path="generated_audio_config.wav"
+    )
requirements.txt CHANGED
@@ -1,24 +1,73 @@
-torch
+gradio
+# core deps
+numpy==1.23.0;python_version<="3.10"
+numpy>=1.24.3;python_version>"3.10"
+matplotlib
+cython>=0.29.30
+scipy>=1.11.2
+torch>=2.1
 torchaudio
-soundfile
 transformers
-gradio
-huggingface_hub
-tqdm
-coqpit
+gdown
 trainer
+soundfile>=0.12.0
+librosa>=0.10.0
+scikit-learn>=1.3.0
+numba==0.55.1;python_version<"3.9"
+numba>=0.57.0;python_version>="3.9"
+inflect>=5.6.0
+tqdm>=4.64.1
+anyascii>=0.3.0
+pyyaml>=6.0
+fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
+aiohttp>=3.8.1
+packaging>=23.1
+mutagen==1.47.0
 librosa
-torchaudio
-einops
+# deps for examples
+flask>=2.0.1
+# deps for inference
+pysbd>=0.3.4
+# deps for notebooks
+umap-learn>=0.5.1
+pandas>=1.4,<2.0
+# deps for training
+matplotlib>=3.7.0
+# coqui stack
+trainer>=0.0.36
+# config management
+coqpit>=0.0.16
+# chinese g2p deps
+jieba
 pypinyin
+# korean
 hangul_romanize
+# gruut+supported langs
+gruut[de,es,fr]==2.2.3
+# deps for korean
+jamo
+nltk
+g2pkk>=0.1.1
+# deps for bangla
+bangla
+bnnumerizer
+bnunicodenormalizer
+# deps for tortoise
+einops>=0.6.0
+transformers>=4.45.2
+# deps for bark
+encodec>=0.1.1
+# deps for XTTS
+unidecode>=1.3.2
 num2words
-spacy
-mutagen
-matplotlib
-pyAudioAnalysis
-eyed3
-hmmlearn
-imblearn
-plotly
-pesq
+# spacy[ja]>=3
+tokenizers==0.20.1
+vinorm==2.0.7
+underthesea==6.8.4
+# remove silence
+hmmlearn==0.3.3
+eyed3==0.9.7
+pesq==0.0.4
+pydub==0.25.1
+pyAudioAnalysis==0.3.14
+ffmpeg-python==0.2.0