Uhhy committed on
Commit
12802b8
verified
1 Parent(s): 24e9d34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -7,14 +7,19 @@ import uuid
7
  from io import StringIO
8
 
9
  import gradio as gr
 
 
 
10
  import spaces
11
  import torch
12
  import torchaudio
13
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 
14
  from TTS.tts.configs.xtts_config import XttsConfig
15
  from TTS.tts.models.xtts import Xtts
16
  from vinorm import TTSnorm
17
 
 
18
  os.system("python -m unidic download")
19
 
20
  HF_TOKEN = None
@@ -54,7 +59,7 @@ supported_languages = config.languages
54
  if not "vi" in supported_languages:
55
  supported_languages.append("vi")
56
  if not "es-AR" in supported_languages:
57
- supported_languages.append("es-AR")
58
 
59
  def normalize_vietnamese_text(text):
60
  text = (
@@ -85,6 +90,14 @@ def calculate_keep_len(text, lang):
85
  return 13000 * word_count + 2000 * num_punct
86
  return -1
87
 
 
 
 
 
 
 
 
 
88
 
89
  @spaces.GPU(duration=0)
90
  def predict(
@@ -138,6 +151,11 @@ def predict(
138
  if normalize_text and language == "vi":
139
  prompt = normalize_vietnamese_text(prompt)
140
 
 
 
 
 
 
141
  t0 = time.time()
142
  out = MODEL.inference(
143
  prompt,
@@ -145,7 +163,7 @@ def predict(
145
  gpt_cond_latent,
146
  speaker_embedding,
147
  repetition_penalty=5.0,
148
- temperature=0.75,
149
  enable_text_splitting=True,
150
  )
151
  inference_time = time.time() - t0
@@ -158,7 +176,11 @@ def predict(
158
  keep_len = calculate_keep_len(prompt, language)
159
  out["wav"] = out["wav"][:keep_len]
160
 
161
- torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 
 
 
 
162
 
163
  except RuntimeError as e:
164
  if "device-side assert" in str(e):
@@ -230,7 +252,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
230
  language_gr = gr.Dropdown(
231
  label="Idioma",
232
  choices=[
233
- "es-AR",
234
  "vi",
235
  "en",
236
  "es",
@@ -251,15 +273,15 @@ with gr.Blocks(analytics_enabled=False) as demo:
251
  "hi",
252
  ],
253
  max_choices=1,
254
- value="es-AR",
255
  )
256
  normalize_text = gr.Checkbox(
257
- label="Normalizar texto en vietnamita",
258
  info="Solo aplicable al idioma vietnamita",
259
  value=True,
260
  )
261
  ref_gr = gr.Audio(
262
- label="Audio de referencia (opcional)",
263
  type="filepath",
264
  value="model/samples/nu-luu-loat.wav",
265
  )
 
7
  from io import StringIO
8
 
9
  import gradio as gr
10
+ import nltk
11
+ import numpy as np
12
+ import pyrubberband
13
  import spaces
14
  import torch
15
  import torchaudio
16
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
17
+ from nltk.sentiment import SentimentIntensityAnalyzer
18
  from TTS.tts.configs.xtts_config import XttsConfig
19
  from TTS.tts.models.xtts import Xtts
20
  from vinorm import TTSnorm
21
 
22
+ nltk.download('vader_lexicon')
23
  os.system("python -m unidic download")
24
 
25
  HF_TOKEN = None
 
59
  if not "vi" in supported_languages:
60
  supported_languages.append("vi")
61
  if not "es-AR" in supported_languages:
62
+ supported_languages.append("es-AR")
63
 
64
  def normalize_vietnamese_text(text):
65
  text = (
 
90
  return 13000 * word_count + 2000 * num_punct
91
  return -1
92
 
93
+ def analyze_sentiment(text):
94
+ sia = SentimentIntensityAnalyzer()
95
+ scores = sia.polarity_scores(text)
96
+ return scores['compound']
97
+
98
+ def change_pitch(audio_data, sampling_rate, sentiment):
99
+ semitones = sentiment * 2
100
+ return pyrubberband.pitch_shift(audio_data, sampling_rate, semitones)
101
 
102
  @spaces.GPU(duration=0)
103
  def predict(
 
151
  if normalize_text and language == "vi":
152
  prompt = normalize_vietnamese_text(prompt)
153
 
154
+ sentiment = analyze_sentiment(prompt)
155
+
156
+ temperature = 0.75 + sentiment * 0.2
157
+ temperature = max(0.5, min(temperature, 1.0))
158
+
159
  t0 = time.time()
160
  out = MODEL.inference(
161
  prompt,
 
163
  gpt_cond_latent,
164
  speaker_embedding,
165
  repetition_penalty=5.0,
166
+ temperature=temperature,
167
  enable_text_splitting=True,
168
  )
169
  inference_time = time.time() - t0
 
176
  keep_len = calculate_keep_len(prompt, language)
177
  out["wav"] = out["wav"][:keep_len]
178
 
179
+ audio_data = np.array(out["wav"])
180
+
181
+ modified_audio = change_pitch(audio_data, 24000, sentiment)
182
+
183
+ torchaudio.save("output.wav", torch.tensor(modified_audio).unsqueeze(0), 24000)
184
 
185
  except RuntimeError as e:
186
  if "device-side assert" in str(e):
 
252
  language_gr = gr.Dropdown(
253
  label="Idioma",
254
  choices=[
255
+ "es-AR",
256
  "vi",
257
  "en",
258
  "es",
 
273
  "hi",
274
  ],
275
  max_choices=1,
276
+ value="es-AR",
277
  )
278
  normalize_text = gr.Checkbox(
279
+ label="Normalizar texto en vietnamita",
280
  info="Solo aplicable al idioma vietnamita",
281
  value=True,
282
  )
283
  ref_gr = gr.Audio(
284
+ label="Audio de referencia (opcional)",
285
  type="filepath",
286
  value="model/samples/nu-luu-loat.wav",
287
  )