Fabrice-TIERCELIN commited on
Commit
dec91cc
·
verified ·
1 Parent(s): 1470851

Handle seed

Browse files
Files changed (1) hide show
  1. app.py +59 -56
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import json
3
  import torch
4
  import time
5
- import wavio
6
 
7
  from tqdm import tqdm
8
  from huggingface_hub import snapshot_download
@@ -11,6 +11,8 @@ from audioldm.audio.stft import TacotronSTFT
11
  from audioldm.variational_autoencoder import AutoencoderKL
12
  from pydub import AudioSegment
13
 
 
 
14
  # Automatic device detection
15
  if torch.cuda.is_available():
16
  device_type = "cuda"
@@ -82,17 +84,21 @@ tango.vae.to(device_type)
82
  tango.stft.to(device_type)
83
  tango.model.to(device_type)
84
 
 
 
 
 
 
85
  def check(
86
  prompt,
87
- output_format,
88
  output_number,
89
  steps,
90
- guidance
 
 
91
  ):
92
  if prompt is None or prompt == "":
93
  raise gr.Error("Please provide a prompt input.")
94
- if not output_format in ["wav", "mp3"]:
95
- raise gr.Error("Please choose an allowed output format (.wav or .mp3).")
96
  if not output_number in [1, 2, 3]:
97
  raise gr.Error("Please ask for 1, 2 or 3 output files.")
98
 
@@ -100,45 +106,31 @@ def update_output(output_format, output_number):
100
  return [
101
  gr.update(format = output_format),
102
  gr.update(format = output_format, visible = (2 <= output_number)),
103
- gr.update(format = output_format, visible = (output_number == 3))
 
104
  ]
105
 
106
  def text2audio(
107
  prompt,
108
- output_format,
109
  output_number,
110
  steps,
111
- guidance
 
 
112
  ):
113
  start = time.time()
114
- output_wave = tango.generate(prompt, steps, guidance, output_number)
115
 
116
- output_filename_1 = "tmp1.wav"
117
- wavio.write(output_filename_1, output_wave[0], rate = 16000, sampwidth = 2)
118
 
119
- if (output_format == "mp3"):
120
- AudioSegment.from_wav("tmp1.wav").export("tmp1.mp3", format = "mp3")
121
- output_filename_1 = "tmp1.mp3"
122
 
123
- if (2 <= output_number):
124
- output_filename_2 = "tmp2.wav"
125
- wavio.write(output_filename_2, output_wave[1], rate = 16000, sampwidth = 2)
126
-
127
- if (output_format == "mp3"):
128
- AudioSegment.from_wav("tmp2.wav").export("tmp2.mp3", format = "mp3")
129
- output_filename_2 = "tmp2.mp3"
130
- else:
131
- output_filename_2 = None
132
-
133
- if (output_number == 3):
134
- output_filename_3 = "tmp3.wav"
135
- wavio.write(output_filename_3, output_wave[2], rate = 16000, sampwidth = 2)
136
 
137
- if (output_format == "mp3"):
138
- AudioSegment.from_wav("tmp3.wav").export("tmp3.mp3", format = "mp3")
139
- output_filename_3 = "tmp3.mp3"
140
- else:
141
- output_filename_3 = None
142
 
143
  end = time.time()
144
  secondes = int(end - start)
@@ -147,10 +139,10 @@ def text2audio(
147
  hours = minutes // 60
148
  minutes = minutes - (hours * 60)
149
  return [
150
- output_filename_1,
151
- output_filename_2,
152
- output_filename_3,
153
- "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec."
154
  ]
155
 
156
  # Gradio interface
@@ -168,45 +160,55 @@ with gr.Blocks() as interface:
168
  <li>If you need to generate <b>music</b>, I recommend to use <i>MusicGen</i>,</li>
169
  </ul>
170
  <br/>
171
- 🐌 Slow process... ~2 hours. Your computer must <b><u>not</u></b> enter into standby mode.<br/>You can duplicate this space on a free account, it works on CPU.<br/>
172
- <a href='https://huggingface.co/spaces/Fabrice-TIERCELIN/Text-to-Audio?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14'></a>
173
  <br/>
174
  ⚖️ You can use, modify and share the generated sounds but not for commercial uses.
175
  """
176
  )
177
  input_text = gr.Textbox(label = "Prompt", value = "Snort of a horse", lines = 2, autofocus = True)
178
- output_format = gr.Radio(label = "Output format", info = "The file you can dowload", choices = ["mp3", "wav"], value = "wav")
179
  with gr.Accordion("Advanced options", open = False):
 
180
  output_number = gr.Slider(label = "Number of generations", info = "1, 2 or 3 output files", minimum = 1, maximum = 3, value = 3, step = 1, interactive = True)
181
- denoising_steps = gr.Slider(label = "Steps", info = "lower=faster & variant, higher=audio quality & similar", minimum = 100, maximum = 200, value = 100, step = 1, interactive = True)
182
  guidance_scale = gr.Slider(label = "Guidance Scale", info = "lower=audio quality, higher=follow the prompt", minimum = 1, maximum = 10, value = 3, step = 0.1, interactive = True)
 
 
183
 
184
  submit = gr.Button("🚀 Generate", variant = "primary")
185
 
186
- output_audio_1 = gr.Audio(label = "Generated Audio #1/3", format = "wav", type="filepath", autoplay = True)
187
- output_audio_2 = gr.Audio(label = "Generated Audio #2/3", format = "wav", type="filepath")
188
- output_audio_3 = gr.Audio(label = "Generated Audio #3/3", format = "wav", type="filepath")
189
  information = gr.Label(label = "Information")
190
 
191
- submit.click(fn = check, inputs = [
 
 
 
 
 
192
  input_text,
193
- output_format,
194
  output_number,
195
  denoising_steps,
196
- guidance_scale
 
 
197
  ], outputs = [], queue = False, show_progress = False).success(fn = update_output, inputs = [
198
  output_format,
199
  output_number
200
  ], outputs = [
201
  output_audio_1,
202
  output_audio_2,
203
- output_audio_3
 
204
  ], queue = False, show_progress = False).success(fn = text2audio, inputs = [
205
  input_text,
206
- output_format,
207
  output_number,
208
  denoising_steps,
209
- guidance_scale
 
 
210
  ], outputs = [
211
  output_audio_1,
212
  output_audio_2,
@@ -218,10 +220,11 @@ with gr.Blocks() as interface:
218
  fn = text2audio,
219
  inputs = [
220
  input_text,
221
- output_format,
222
  output_number,
223
  denoising_steps,
224
- guidance_scale
 
 
225
  ],
226
  outputs = [
227
  output_audio_1,
@@ -230,11 +233,11 @@ with gr.Blocks() as interface:
230
  information
231
  ],
232
  examples = [
233
- ["A hammer is hitting a wooden surface", "mp3", 3, 100, 3],
234
- ["Peaceful and calming ambient music with singing bowl and other instruments.", "wav", 3, 100, 3],
235
- ["A man is speaking in a small room.", "mp3", 2, 100, 3],
236
- ["A female is speaking followed by footstep sound", "mp3", 1, 100, 3],
237
- ["Wooden table tapping sound followed by water pouring sound.", "mp3", 3, 200, 3],
238
  ],
239
  cache_examples = "lazy",
240
  )
 
2
  import json
3
  import torch
4
  import time
5
+ import random
6
 
7
  from tqdm import tqdm
8
  from huggingface_hub import snapshot_download
 
11
  from audioldm.variational_autoencoder import AutoencoderKL
12
  from pydub import AudioSegment
13
 
14
+ max_64_bit_int = 2**63 - 1
15
+
16
  # Automatic device detection
17
  if torch.cuda.is_available():
18
  device_type = "cuda"
 
84
  tango.stft.to(device_type)
85
  tango.model.to(device_type)
86
 
87
+ def update_seed(is_randomize_seed, seed):
88
+ if is_randomize_seed:
89
+ return random.randint(0, max_64_bit_int)
90
+ return seed
91
+
92
  def check(
93
  prompt,
 
94
  output_number,
95
  steps,
96
+ guidance,
97
+ is_randomize_seed,
98
+ seed
99
  ):
100
  if prompt is None or prompt == "":
101
  raise gr.Error("Please provide a prompt input.")
 
 
102
  if not output_number in [1, 2, 3]:
103
  raise gr.Error("Please ask for 1, 2 or 3 output files.")
104
 
 
106
  return [
107
  gr.update(format = output_format),
108
  gr.update(format = output_format, visible = (2 <= output_number)),
109
+ gr.update(format = output_format, visible = (output_number == 3)),
110
+ gr.update(visible = False)
111
  ]
112
 
113
  def text2audio(
114
  prompt,
 
115
  output_number,
116
  steps,
117
+ guidance,
118
+ is_randomize_seed,
119
+ seed
120
  ):
121
  start = time.time()
 
122
 
123
+ if seed is None:
124
+ seed = random.randint(0, max_64_bit_int)
125
 
126
+ random.seed(seed)
127
+ torch.manual_seed(seed)
 
128
 
129
+ output_wave = tango.generate(prompt, steps, guidance, output_number)
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ output_wave_1 = gr.make_waveform((16000, output_wave[0]))
132
+ output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
133
+ output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
 
 
134
 
135
  end = time.time()
136
  secondes = int(end - start)
 
139
  hours = minutes // 60
140
  minutes = minutes - (hours * 60)
141
  return [
142
+ output_wave_1,
143
+ output_wave_2,
144
+ output_wave_3,
145
+ gr.update(visible = True, value = "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
146
  ]
147
 
148
  # Gradio interface
 
160
  <li>If you need to generate <b>music</b>, I recommend to use <i>MusicGen</i>,</li>
161
  </ul>
162
  <br/>
163
+ 🐌 Slow process... ~5 min. Your computer must <b><u>not</u></b> enter into standby mode.<br/>You can duplicate this space on a free account, it works on CPU.<br/>
164
+ <a href='https://huggingface.co/spaces/Fabrice-TIERCELIN/Text-to-Audio?duplicate=true&hidden=public&hidden=public'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14'></a>
165
  <br/>
166
  ⚖️ You can use, modify and share the generated sounds but not for commercial uses.
167
  """
168
  )
169
  input_text = gr.Textbox(label = "Prompt", value = "Snort of a horse", lines = 2, autofocus = True)
 
170
  with gr.Accordion("Advanced options", open = False):
171
+ output_format = gr.Radio(label = "Output format", info = "The file you can dowload", choices = ["mp3", "wav"], value = "wav")
172
  output_number = gr.Slider(label = "Number of generations", info = "1, 2 or 3 output files", minimum = 1, maximum = 3, value = 3, step = 1, interactive = True)
173
+ denoising_steps = gr.Slider(label = "Steps", info = "lower=faster & variant, higher=audio quality & similar", minimum = 10, maximum = 200, value = 10, step = 1, interactive = True)
174
  guidance_scale = gr.Slider(label = "Guidance Scale", info = "lower=audio quality, higher=follow the prompt", minimum = 1, maximum = 10, value = 3, step = 0.1, interactive = True)
175
+ randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
176
+ seed = gr.Slider(minimum = 0, maximum = max_64_bit_int, step = 1, randomize = True, label = "Seed")
177
 
178
  submit = gr.Button("🚀 Generate", variant = "primary")
179
 
180
+ output_audio_1 = gr.Audio(label = "Generated Audio #1/3", format = "wav", type="numpy", autoplay = True)
181
+ output_audio_2 = gr.Audio(label = "Generated Audio #2/3", format = "wav", type="numpy")
182
+ output_audio_3 = gr.Audio(label = "Generated Audio #3/3", format = "wav", type="numpy")
183
  information = gr.Label(label = "Information")
184
 
185
+ submit.click(fn = update_seed, inputs = [
186
+ randomize_seed,
187
+ seed
188
+ ], outputs = [
189
+ seed
190
+ ], queue = False, show_progress = False).then(fn = check, inputs = [
191
  input_text,
 
192
  output_number,
193
  denoising_steps,
194
+ guidance_scale,
195
+ randomize_seed,
196
+ seed
197
  ], outputs = [], queue = False, show_progress = False).success(fn = update_output, inputs = [
198
  output_format,
199
  output_number
200
  ], outputs = [
201
  output_audio_1,
202
  output_audio_2,
203
+ output_audio_3,
204
+ information
205
  ], queue = False, show_progress = False).success(fn = text2audio, inputs = [
206
  input_text,
 
207
  output_number,
208
  denoising_steps,
209
+ guidance_scale,
210
+ randomize_seed,
211
+ seed
212
  ], outputs = [
213
  output_audio_1,
214
  output_audio_2,
 
220
  fn = text2audio,
221
  inputs = [
222
  input_text,
 
223
  output_number,
224
  denoising_steps,
225
+ guidance_scale,
226
+ randomize_seed,
227
+ seed
228
  ],
229
  outputs = [
230
  output_audio_1,
 
233
  information
234
  ],
235
  examples = [
236
+ ["A hammer is hitting a wooden surface", 3, 100, 3, False, 123],
237
+ ["Peaceful and calming ambient music with singing bowl and other instruments.", 3, 100, 3, False, 123],
238
+ ["A man is speaking in a small room.", 2, 100, 3, False, 123],
239
+ ["A female is speaking followed by footstep sound", 1, 100, 3, False, 123],
240
+ ["Wooden table tapping sound followed by water pouring sound.", 3, 200, 3, False, 123],
241
  ],
242
  cache_examples = "lazy",
243
  )