Commit
fbf5d25
1 Parent(s): 8b43f70

1, 2 or 3 output files (#3)

Browse files

- 1, 2 or 3 output files (8ad7a258d2c589a1ecaf81b0161a200c278548c9)


Co-authored-by: Fabrice TIERCELIN <[email protected]>

Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -49,12 +49,12 @@ class Tango:
49
  self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
50
 
51
  def chunks(self, lst, n):
52
- """ Yield successive n-sized chunks from a list. """
53
  for i in range(0, len(lst), n):
54
  yield lst[i:i + n]
55
 
56
  def generate(self, prompt, steps=100, guidance=3, samples=3, disable_progress=True):
57
- """ Genrate audio for a single prompt string. """
58
  with torch.no_grad():
59
  latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
60
  mel = self.vae.decode_first_stage(latents)
@@ -62,7 +62,7 @@ class Tango:
62
  return wave
63
 
64
  def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
65
- """ Genrate audio for a list of prompt strings. """
66
  outputs = []
67
  for k in tqdm(range(0, len(prompts), batch_size)):
68
  batch = prompts[k: k+batch_size]
@@ -84,24 +84,42 @@ tango.stft.to(device_type)
84
  tango.model.to(device_type)
85
 
86
  @spaces.GPU(duration=120)
87
- def gradio_generate(prompt, output_format, steps, guidance):
88
- output_wave = tango.generate(prompt, steps, guidance)
 
 
 
 
 
 
89
  # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
90
 
91
  output_filename_1 = "tmp1.wav"
92
- wavio.write(output_filename_1, output_wave[0], rate=16000, sampwidth=2)
93
- output_filename_2 = "tmp2.wav"
94
- wavio.write(output_filename_2, output_wave[1], rate=16000, sampwidth=2)
95
- output_filename_3 = "tmp3.wav"
96
- wavio.write(output_filename_3, output_wave[2], rate=16000, sampwidth=2)
97
 
98
  if (output_format == "mp3"):
99
  AudioSegment.from_wav("tmp1.wav").export("tmp1.mp3", format = "mp3")
100
  output_filename_1 = "tmp1.mp3"
101
- AudioSegment.from_wav("tmp2.wav").export("tmp2.mp3", format = "mp3")
102
- output_filename_2 = "tmp2.mp3"
103
- AudioSegment.from_wav("tmp3.wav").export("tmp3.mp3", format = "mp3")
104
- output_filename_3 = "tmp3.mp3"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  return [output_filename_1, output_filename_2, output_filename_3]
107
 
@@ -133,16 +151,17 @@ Generate audio using Tango2 by providing a text prompt. Tango2 was built from Ta
133
  # Gradio input and output components
134
  input_text = gr.Textbox(lines=2, label="Prompt")
135
  output_format = gr.Radio(label = "Output format", info = "The file you can download", choices = ["mp3", "wav"], value = "wav")
 
136
  output_audio_1 = gr.Audio(label="Generated Audio #1/3", type="filepath")
137
  output_audio_2 = gr.Audio(label="Generated Audio #2/3", type="filepath")
138
  output_audio_3 = gr.Audio(label="Generated Audio #3/3", type="filepath")
139
- denoising_steps = gr.Slider(minimum=100, maximum=200, value=100, step=1, label="Steps", interactive=True)
140
  guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
141
 
142
  # Gradio interface
143
  gr_interface = gr.Interface(
144
  fn=gradio_generate,
145
- inputs=[input_text, output_format, denoising_steps, guidance_scale],
146
  outputs=[output_audio_1, output_audio_2, output_audio_3],
147
  title="Tango 2: Aligning Diffusion-based Text-to-Audio Generations through Direct Preference Optimization",
148
  description=description_text,
 
49
  self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
50
 
51
  def chunks(self, lst, n):
52
+ # Yield successive n-sized chunks from a list
53
  for i in range(0, len(lst), n):
54
  yield lst[i:i + n]
55
 
56
  def generate(self, prompt, steps=100, guidance=3, samples=3, disable_progress=True):
57
+ # Genrate audio for a single prompt string
58
  with torch.no_grad():
59
  latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
60
  mel = self.vae.decode_first_stage(latents)
 
62
  return wave
63
 
64
  def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
65
+ # Genrate audio for a list of prompt strings
66
  outputs = []
67
  for k in tqdm(range(0, len(prompts), batch_size)):
68
  batch = prompts[k: k+batch_size]
 
84
  tango.model.to(device_type)
85
 
86
  @spaces.GPU(duration=120)
87
+ def gradio_generate(
88
+ prompt,
89
+ output_format,
90
+ output_number,
91
+ steps,
92
+ guidance
93
+ ):
94
+ output_wave = tango.generate(prompt, steps, guidance, output_number)
95
  # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
96
 
97
  output_filename_1 = "tmp1.wav"
98
+ wavio.write(output_filename_1, output_wave[0], rate = 16000, sampwidth = 2)
 
 
 
 
99
 
100
  if (output_format == "mp3"):
101
  AudioSegment.from_wav("tmp1.wav").export("tmp1.mp3", format = "mp3")
102
  output_filename_1 = "tmp1.mp3"
103
+
104
+ if (2 <= output_number):
105
+ output_filename_2 = "tmp2.wav"
106
+ wavio.write(output_filename_2, output_wave[1], rate = 16000, sampwidth = 2)
107
+
108
+ if (output_format == "mp3"):
109
+ AudioSegment.from_wav("tmp2.wav").export("tmp2.mp3", format = "mp3")
110
+ output_filename_2 = "tmp2.mp3"
111
+ else:
112
+ output_filename_2 = None
113
+
114
+ if (output_number == 3):
115
+ output_filename_3 = "tmp3.wav"
116
+ wavio.write(output_filename_3, output_wave[2], rate = 16000, sampwidth = 2)
117
+
118
+ if (output_format == "mp3"):
119
+ AudioSegment.from_wav("tmp3.wav").export("tmp3.mp3", format = "mp3")
120
+ output_filename_3 = "tmp3.mp3"
121
+ else:
122
+ output_filename_3 = None
123
 
124
  return [output_filename_1, output_filename_2, output_filename_3]
125
 
 
151
  # Gradio input and output components
152
  input_text = gr.Textbox(lines=2, label="Prompt")
153
  output_format = gr.Radio(label = "Output format", info = "The file you can download", choices = ["mp3", "wav"], value = "wav")
154
+ output_number = gr.Slider(label = "Number of generations", info = "1, 2 or 3 output files", minimum = 1, maximum = 3, value = 3, step = 1, interactive = True)
155
  output_audio_1 = gr.Audio(label="Generated Audio #1/3", type="filepath")
156
  output_audio_2 = gr.Audio(label="Generated Audio #2/3", type="filepath")
157
  output_audio_3 = gr.Audio(label="Generated Audio #3/3", type="filepath")
158
+ denoising_steps = gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Steps", interactive=True)
159
  guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
160
 
161
  # Gradio interface
162
  gr_interface = gr.Interface(
163
  fn=gradio_generate,
164
+ inputs=[input_text, output_format, output_number, denoising_steps, guidance_scale],
165
  outputs=[output_audio_1, output_audio_2, output_audio_3],
166
  title="Tango 2: Aligning Diffusion-based Text-to-Audio Generations through Direct Preference Optimization",
167
  description=description_text,