seawolf2357 committed
Commit 382d3da · verified · 1 parent: 062c717

Create app.py

Files changed (1)
  1. app.py +374 -0
app.py ADDED
@@ -0,0 +1,374 @@
+import spaces
+import gradio as gr
+import json
+import torch
+import wavio
+from tqdm import tqdm
+from huggingface_hub import snapshot_download
+from models import AudioDiffusion, DDPMScheduler
+from audioldm.audio.stft import TacotronSTFT
+from audioldm.variational_autoencoder import AutoencoderKL
+from pydub import AudioSegment
+from gradio import Markdown
+
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers import DiffusionPipeline, AudioPipelineOutput
+from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast, pipeline
+from typing import Union
+from diffusers.utils.torch_utils import randn_tensor
+from langdetect import detect, DetectorFactory
+
+# Ensure consistent results from langdetect
+DetectorFactory.seed = 0
+
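+# Diffusers-style pipeline that bundles the TANGO 2 components (AudioLDM
+# variational autoencoder, T5 text encoder/tokenizer, UNet, DDPM scheduler)
+# behind a single __call__ for text-to-audio generation.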
+class Tango2Pipeline(DiffusionPipeline):
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: T5EncoderModel,
+        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+        unet: UNet2DConditionModel,
+        scheduler: DDPMScheduler
+    ):
+        super().__init__()
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler
+        )
+
+    def _encode_prompt(self, prompt):
+        device = self.text_encoder.device
+
+        batch = self.tokenizer(
+            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+        )
+        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+        encoder_hidden_states = self.text_encoder(
+            input_ids=input_ids, attention_mask=attention_mask
+        )[0]
+
+        boolean_encoder_mask = (attention_mask == 1).to(device)
+
+        return encoder_hidden_states, boolean_encoder_mask
+
+    def _encode_text_classifier_free(self, prompt, num_samples_per_prompt):
+        device = self.text_encoder.device
+        batch = self.tokenizer(
+            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+        )
+        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+        with torch.no_grad():
+            prompt_embeds = self.text_encoder(
+                input_ids=input_ids, attention_mask=attention_mask
+            )[0]
+
+        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+        # get unconditional embeddings for classifier free guidance
+        uncond_tokens = [""] * len(prompt)
+
+        max_length = prompt_embeds.shape[1]
+        uncond_batch = self.tokenizer(
+            uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+        )
+        uncond_input_ids = uncond_batch.input_ids.to(device)
+        uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+        with torch.no_grad():
+            negative_prompt_embeds = self.text_encoder(
+                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+            )[0]
+
+        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+        uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+        # For classifier free guidance, we need to do two forward passes.
+        # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+        boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+        return prompt_embeds, boolean_prompt_mask
+
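+    # prepare_latents draws the initial Gaussian noise in TANGO's latent
+    # mel-spectrogram shape (batch, channels, 256, 16) and scales it by the
+    # scheduler's init_noise_sigma.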
+    def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
+        shape = (batch_size, num_channels_latents, 256, 16)
+        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * inference_scheduler.init_noise_sigma
+        return latents
+
+    @torch.no_grad()
+    def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+                  disable_progress=True):
+        device = self.text_encoder.device
+        classifier_free_guidance = guidance_scale > 1.0
+        batch_size = len(prompt) * num_samples_per_prompt
+
+        if classifier_free_guidance:
+            prompt_embeds, boolean_prompt_mask = self._encode_text_classifier_free(prompt, num_samples_per_prompt)
+        else:
+            prompt_embeds, boolean_prompt_mask = self._encode_prompt(prompt)
+            prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+            boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+        inference_scheduler.set_timesteps(num_steps, device=device)
+        timesteps = inference_scheduler.timesteps
+
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+        num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+        progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+        for i, t in enumerate(timesteps):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+            latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+            noise_pred = self.unet(
+                latent_model_input, t, encoder_hidden_states=prompt_embeds,
+                encoder_attention_mask=boolean_prompt_mask
+            ).sample
+
+            # perform guidance
+            if classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+                progress_bar.update(1)
+
+        return latents
+
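+    # __call__ runs the denoising loop on a single prompt, decodes the latents
+    # to a mel spectrogram with the VAE, then converts the mel back to audio
+    # via decode_to_waveform.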
+    @torch.no_grad()
+    def __call__(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
+        """ Generate audio for a single prompt string. """
+        with torch.no_grad():
+            latents = self.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
+            mel = self.vae.decode_first_stage(latents)
+            wave = self.vae.decode_to_waveform(mel)
+
+        return AudioPipelineOutput(audios=wave)
+
+# Automatic device detection
+if torch.cuda.is_available():
+    device_type = "cuda"
+    device_selection = "cuda:0"
+else:
+    device_type = "cpu"
+    device_selection = "cpu"
+
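+# Tango downloads the declare-lab/tango2 checkpoint from the Hugging Face Hub
+# and loads the VAE, STFT, and latent diffusion model weights.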
+class Tango:
+    def __init__(self, name="declare-lab/tango2", device=device_selection):
+
+        path = snapshot_download(repo_id=name)
+
+        vae_config = json.load(open("{}/vae_config.json".format(path)))
+        stft_config = json.load(open("{}/stft_config.json".format(path)))
+        main_config = json.load(open("{}/main_config.json".format(path)))
+
+        self.vae = AutoencoderKL(**vae_config).to(device)
+        self.stft = TacotronSTFT(**stft_config).to(device)
+        self.model = AudioDiffusion(**main_config).to(device)
+
+        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
+        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
+        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)
+
+        self.vae.load_state_dict(vae_weights)
+        self.stft.load_state_dict(stft_weights)
+        self.model.load_state_dict(main_weights)
+
+        print("Successfully loaded checkpoint from:", name)
+
+        self.vae.eval()
+        self.stft.eval()
+        self.model.eval()
+
+        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
+
+    def chunks(self, lst, n):
+        """ Yield successive n-sized chunks from a list. """
+        for i in range(0, len(lst), n):
+            yield lst[i:i + n]
+
+    def generate(self, prompt, steps=200, guidance=8, samples=1, disable_progress=True):
+        """ Generate audio for a single prompt string. """
+        with torch.no_grad():
+            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
+            mel = self.vae.decode_first_stage(latents)
+            wave = self.vae.decode_to_waveform(mel)
+        return wave[0]
+
+    def generate_for_batch(self, prompts, steps=200, guidance=8, samples=1, batch_size=8, disable_progress=True):
+        """ Generate audio for a list of prompt strings. """
+        outputs = []
+        for k in tqdm(range(0, len(prompts), batch_size)):
+            batch = prompts[k: k+batch_size]
+            with torch.no_grad():
+                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
+                mel = self.vae.decode_first_stage(latents)
+                wave = self.vae.decode_to_waveform(mel)
+            outputs += [item for item in wave]
+        if samples == 1:
+            return outputs
+        else:
+            return list(self.chunks(outputs, samples))
+
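+# Load the checkpoints once at start-up; the Tango2Pipeline below reuses the
+# same modules.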
+# Initialize TANGO
+tango = Tango(device=device_selection)
+tango.vae.to(device_type)
+tango.stft.to(device_type)
+tango.model.to(device_type)
+
+pipe = Tango2Pipeline(
+    vae=tango.vae,
+    text_encoder=tango.model.text_encoder,
+    tokenizer=tango.model.tokenizer,
+    unet=tango.model.unet,
+    scheduler=tango.scheduler
+)
+
+# Initialize Translation Pipeline (Korean -> English)
+translation_pipeline = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+
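+# Note: exporting mp3 with pydub requires ffmpeg to be available on the system.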
+def adjust_audio_length(audio_path, desired_length_sec, output_format):
+    """
+    Adjust the audio to the desired length.
+    If the audio is shorter, pad with silence.
+    If longer, trim the audio.
+    """
+    audio = AudioSegment.from_file(audio_path)
+    desired_length_ms = desired_length_sec * 1000  # Convert to milliseconds
+
+    if len(audio) < desired_length_ms:
+        # Pad with silence
+        padding = AudioSegment.silent(duration=desired_length_ms - len(audio))
+        audio += padding
+    elif len(audio) > desired_length_ms:
+        # Trim the audio
+        audio = audio[:desired_length_ms]
+
+    # Export the adjusted audio
+    adjusted_path = f"adjusted.{output_format}"
+    audio.export(adjusted_path, format=output_format)
+    return adjusted_path
+
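+# spaces.GPU requests a GPU for this call on Hugging Face ZeroGPU Spaces;
+# duration is the requested time limit in seconds.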
+@spaces.GPU(duration=60)
+def gradio_generate(prompt, output_format, steps, guidance, audio_length):
+    """
+    Generate audio based on the prompt, translate if necessary, and adjust its length.
+    """
+    # Detect language
+    try:
+        lang = detect(prompt)
+    except Exception:
+        lang = "unknown"
+
+    # If the prompt is in Korean, translate to English
+    if lang == "ko":
+        translated = translation_pipeline(prompt)[0]['translation_text']
+        print(f"Translated Prompt: {translated}")
+        prompt_to_use = translated
+    else:
+        prompt_to_use = prompt
+
+    # Generate audio using the pipeline
+    output_wave = pipe(prompt_to_use, steps, guidance)
+    output_wave = output_wave.audios[0]
+    temp_wav = "temp.wav"
+    wavio.write(temp_wav, output_wave, rate=16000, sampwidth=2)
+
+    # Adjust audio length
+    adjusted_path = adjust_audio_length(temp_wav, audio_length, output_format)
+
+    return adjusted_path
+
+# Gradio input and output components
+input_text = gr.Textbox(lines=2, label="Prompt")
+output_format = gr.Radio(
+    label="Output Format",
+    info="The file you can download",
+    choices=["mp3", "wav"],
+    value="wav"
+)
+audio_length = gr.Slider(
+    minimum=4,
+    maximum=10,
+    step=1,
+    label="Audio Length (seconds)",
+    value=6,
+    interactive=True
+)
+output_audio = gr.Audio(label="Generated Audio", type="filepath")
+denoising_steps = gr.Slider(
+    minimum=100,
+    maximum=200,
+    step=1,
+    label="Steps",
+    value=200,  # Changed from 100 to 200
+    interactive=True
+)
+guidance_scale = gr.Slider(
+    minimum=1,
+    maximum=10,
+    step=0.1,
+    label="Guidance Scale",
+    value=8,  # Changed from 3 to 8
+    interactive=True
+)
+
+# Gradio interface
+gr_interface = gr.Interface(
+    theme="Nymbo/Nymbo_Theme",
+    fn=gradio_generate,
+    inputs=[input_text, output_format, denoising_steps, guidance_scale, audio_length],
+    outputs=[output_audio],
+    title="T2: Text to SoundFX",
+    allow_flagging=False,
+    examples=[
+        ["조용한 말소리 후 비행기가 멀어지는 소리"],  # Quiet talking, then an airplane moving away
+        ["사람들이 환호하고 박수치는 소리"],  # People cheering and clapping
+        ["강한 바람 소리와 빗소리"],  # Strong wind and the sound of rain
+        ["Quiet speech and then an airplane flying away"],
+        ["A bicycle peddling on dirt and gravel followed by a man speaking then laughing"],
+        ["Ducks quack and water splashes with some animal screeching in the background"],
+        ["Describe the sound of the ocean"],
+        ["A woman and a baby are having a conversation"],
+        ["A man speaks followed by a popping noise and laughter"],
+        ["A cup is filled from a faucet"],
+        ["An audience cheering and clapping"],
+        ["Rolling thunder with lightning strikes"],
+        ["A dog barking and a cat mewing and a racing car passes by"],
+        ["Gentle water stream, birds chirping and sudden gun shot"],
+        ["A man talking followed by a goat baaing then a metal gate sliding shut as ducks quack and wind blows into a microphone."],
+        ["A dog barking"],
+        ["A cat meowing"],
+        ["Wooden table tapping sound while water pouring"],
+        ["Applause from a crowd with distant clicking and a man speaking over a loudspeaker"],
+        ["two gunshots followed by birds flying away while chirping"],
+        ["Whistling with birds chirping"],
+        ["A person snoring"],
+        ["Motor vehicles are driving with loud engines and a person whistles"],
+        ["People cheering in a stadium while thunder and lightning strikes"],
+        ["A helicopter is in flight"],
+        ["A dog barking and a man talking and a racing car passes by"],
+    ],
+    cache_examples="lazy",  # Turn on to cache.
+)
+
+# Launch Gradio app
+gr_interface.queue(10).launch()