nightfury committed on
Commit
04646cc
·
1 Parent(s): cdac4e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr

import torch

from spectro import wav_bytes_from_spectrogram_image
from diffusers import StableDiffusionPipeline

from transformers import BlipForConditionalGeneration, BlipProcessor

from share_btn import community_icon_html, loading_icon_html, share_js

# Model identifiers on the Hugging Face Hub.
model_id = "riffusion/riffusion-model-v1"
blip_model_id = "Salesforce/blip-image-captioning-base"

# Riffusion is a Stable Diffusion checkpoint fine-tuned to emit spectrogram
# images; loaded at module import time and moved to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cuda")

# BLIP image-captioning model, loaded in fp16 on the GPU; inputs passed to it
# later must therefore also be cast to float16 (see predict()).
blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_id, torch_dtype=torch.float16).to("cuda")
processor = BlipProcessor.from_pretrained(blip_model_id)
def predict(image):
    """Caption an input image with BLIP, then generate music from the caption.

    Pipeline: image -> BLIP caption -> riffusion spectrogram image ->
    WAV audio written to ``output.wav``.

    Parameters
    ----------
    image : image accepted by the BLIP processor (PIL image / array from
        the Gradio Image component — TODO confirm exact type at the caller).

    Returns
    -------
    tuple
        (spectrogram image, path to the WAV file, and three
        ``gr.update(visible=True)`` objects that reveal the share button
        and its icons in the UI).
    """
    # Cast inputs to fp16 on the GPU to match blip_model's torch_dtype.
    inputs = processor(image, return_tensors="pt").to("cuda", torch.float16)
    output_blip = blip_model.generate(**inputs)
    prompt = processor.decode(output_blip[0], skip_special_tokens=True)

    # Use the caption as the text prompt for the riffusion pipeline.
    spec = pipe(prompt).images[0]
    # NOTE(review): removed a leftover debug `print(spec)` here.
    # wav_bytes_from_spectrogram_image appears to return a tuple whose first
    # element is a bytes-like buffer — TODO confirm against spectro module.
    wav = wav_bytes_from_spectrogram_image(spec)
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())
    return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
# Banner HTML shown at the top of the app.
# Fix: the description paragraph's <p> tag was opened but never closed,
# producing invalid HTML (</div> appeared where </p> was expected).
title = """
<div style="text-align: center; max-width: 500px; margin: 0 auto;">
  <div
    style="
      display: inline-flex;
      align-items: center;
      gap: 0.8rem;
      font-size: 1.75rem;
      margin-bottom: 10px;
    "
  >
    <h1 style="font-weight: 600; margin-bottom: 7px;">
      Riffusion real-time prompt to image and to music generation system
    </h1>
  </div>
  <p style="margin-bottom: 10px;font-size: 94%;font-weight: 100;line-height: 1.5em;">
    Describe a musical prompt and generate a respective spectrogram image & musical sound associated with.
  </p>
</div>
"""
# Footer/about HTML rendered below the outputs: model description, intended
# research use, and attribution links. Kept verbatim — this string is emitted
# to the browser as-is via gr.HTML(article).
article = """
<p style="font-size: 0.8em;line-height: 1.2em;border: 1px solid #374151;border-radius: 8px;padding: 20px;">
About the model: Riffusion is a latent text2img diffusion model capable of generating spectrogram images from a given text input prompts. These generated spectrograms are again then utilised to get converted into audio clips.
<br />—
<br />The Riffusion model was created by fine-tuning the Stable-Diffusion-v1-5 checkpoint.
<br />—
<br />The model is intended for research purposes only. Possible research areas and tasks include
generation of artworks, audio, and use in creative processes, applications in educational or creative tools, research on generative models.
</p>
<div class="footer">
    <p>
    <a href="https://huggingface.co/riffusion/riffusion-model-v1" target="_blank">Riffusion model</a> by Seth Forsgren and Hayk Martiros -
    <a href="https://github.com/salesforce/BLIP" target="_blank"> BLIP Model </a> by Junnan Li et al. - Demo forked from 🤗 <a href="https://huggingface.co/nightfury" target="_blank">Nightfury</a>'s demo
    </p>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=...). Styles the two centered
# column containers, the footer (light and dark themes), the spinner used
# while sharing, and the community share button. Kept verbatim — these
# selectors target Gradio-generated element ids/classes.
css = '''
#col-container, #col-container-2 {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
div#record_btn > .mt-6 {
    margin-top: 0!important;
}
div#record_btn > .mt-6 button {
    width: 100%;
    height: 40px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 10px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
    all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2){
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
'''
# Gradio UI: an input column (image + submit button) and an output column
# (spectrogram image, audio player, hidden-until-ready share controls).
# NOTE(review): nesting reconstructed from the diff view — confirm against
# the rendered layout.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):

        gr.HTML(title)

        # prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
        image_input = gr.Image()
        send_btn = gr.Button(value="Get a new riffusion spectrogram ! ", elem_id="submit-btn")

    with gr.Column(elem_id="col-container-2"):

        spectrogram_output = gr.Image(label="riffusion spectrogram image result", elem_id="img-out")
        sound_output = gr.Audio(type='filepath', label="riffusion spectrogram sound", elem_id="music-out")

        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(community_icon_html, visible=False)
            loading_icon = gr.HTML(loading_icon_html, visible=False)
            share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

        gr.HTML(article)

    # predict() returns 5 values: image, wav path, and three visibility
    # updates that reveal the share button and its icons.
    send_btn.click(predict, inputs=[image_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
    # Share handler runs entirely client-side via the `_js` hook (legacy
    # Gradio 3.x keyword — renamed `js` in Gradio 4; do not "modernize"
    # without upgrading the gradio dependency).
    share_button.click(None, [], [], _js=share_js)

# Queue requests (max 250 waiting) since each prediction ties up the GPU.
demo.queue(max_size=250).launch(debug=True)