multimodalart HF Staff committed on
Commit
929d1f5
·
1 Parent(s): 1ee5745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -25
app.py CHANGED
@@ -11,11 +11,6 @@ from diffusers import StableDiffusionImg2ImgPipeline
11
  from share_btn import community_icon_html, loading_icon_html, share_js
12
 
13
  device = "cuda"
14
- MODEL_ID = "riffusion/riffusion-model-v1"
15
- pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
16
- pipe = pipe.to(device)
17
- pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
18
- pipe2 = pipe2.to(device)
19
 
20
  spectro_from_wav = gr.Interface.load("spaces/fffiloni/audio-to-spectrogram")
21
 
@@ -25,7 +20,10 @@ def predict(prompt, negative_prompt, audio_input, duration):
25
  else :
26
  return style_transfer(prompt, negative_prompt, audio_input)
27
 
28
- def classic(prompt, negative_prompt, duration):
 
 
 
29
  if duration == 5:
30
  width_duration=512
31
  else :
@@ -37,23 +35,6 @@ def classic(prompt, negative_prompt, duration):
37
  f.write(wav[0].getbuffer())
38
  return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
39
 
40
- def style_transfer(prompt, negative_prompt, audio_input):
41
- spec = spectro_from_wav(audio_input)
42
- print(spec)
43
- # Open the image
44
- im = Image.open(spec)
45
-
46
-
47
- # Open the image
48
- im = image_from_spectrogram(im, 1)
49
-
50
-
51
- new_spectro = pipe2(prompt=prompt, image=im, strength=0.5, guidance_scale=7).images
52
- wav = wav_bytes_from_spectrogram_image(new_spectro[0])
53
- with open("output.wav", "wb") as f:
54
- f.write(wav[0].getbuffer())
55
- return new_spectro[0], 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
56
-
57
  def image_from_spectrogram(
58
  spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
59
  ) -> Image.Image:
@@ -194,7 +175,7 @@ with gr.Blocks(css=css) as demo:
194
  with gr.Column(elem_id="col-container"):
195
 
196
  gr.HTML(title)
197
-
198
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
199
  audio_input = gr.Audio(source="upload", type="filepath", visible=False)
200
  with gr.Row():
@@ -215,7 +196,7 @@ with gr.Blocks(css=css) as demo:
215
 
216
  gr.HTML(article)
217
 
218
- send_btn.click(predict, inputs=[prompt_input, negative_prompt, audio_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
219
  share_button.click(None, [], [], _js=share_js)
220
 
221
  demo.queue(max_size=250).launch(debug=True)
 
11
  from share_btn import community_icon_html, loading_icon_html, share_js
12
 
13
  device = "cuda"
 
 
 
 
 
14
 
15
  spectro_from_wav = gr.Interface.load("spaces/fffiloni/audio-to-spectrogram")
16
 
 
20
  else :
21
  return style_transfer(prompt, negative_prompt, audio_input)
22
 
23
+ def classic(model_input, prompt, negative_prompt, duration):
24
+ pipe = StableDiffusionPipeline.from_pretrained(model_input, torch_dtype=torch.float16)
25
+ pipe = pipe.to(device)
26
+
27
  if duration == 5:
28
  width_duration=512
29
  else :
 
35
  f.write(wav[0].getbuffer())
36
  return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def image_from_spectrogram(
39
  spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
40
  ) -> Image.Image:
 
175
  with gr.Column(elem_id="col-container"):
176
 
177
  gr.HTML(title)
178
+ model_input = gr.Textbox(placeholder="Your Riffusion fine-tuned model Hugging Face ID")
179
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
180
  audio_input = gr.Audio(source="upload", type="filepath", visible=False)
181
  with gr.Row():
 
196
 
197
  gr.HTML(article)
198
 
199
+ send_btn.click(predict, inputs=[model_input, prompt_input, negative_prompt, audio_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
200
  share_button.click(None, [], [], _js=share_js)
201
 
202
  demo.queue(max_size=250).launch(debug=True)