AngeT10 committed on
Commit
876a645
1 Parent(s): 48e22a6

Create app.py

Files changed (1)
  app.py (+309, -0)
app.py ADDED
@@ -0,0 +1,309 @@
import os

# Clone the Show-1 repository so the `showone` package and its pipelines can be imported
# (assumes the cloned repo ends up on the import path, as in the original Space setup).
os.system("git clone https://github.com/showlab/Show-1.git")

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image

from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
from diffusers.utils import export_to_video
from diffusers.utils.torch_utils import randn_tensor

from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond

# Base Model
pretrained_model_path = "showlab/show-1-base"
pipe_base = TextToVideoIFPipeline.from_pretrained(
    pretrained_model_path,
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe_base.enable_model_cpu_offload()

# Interpolation Model
pretrained_model_path = "showlab/show-1-interpolation"
pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
    pretrained_model_path,
    text_encoder=None,
    torch_dtype=torch.float16,
    variant="fp16"
)
pipe_interp_1.enable_model_cpu_offload()

# Super-Resolution Model 1
# Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
    pretrained_model_path,
    text_encoder=None,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_sr_1_image.enable_model_cpu_offload()

pretrained_model_path = "showlab/show-1-sr1"
pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
    pretrained_model_path,
    text_encoder=None,
    torch_dtype=torch.float16
)
pipe_sr_1_cond.enable_model_cpu_offload()

# Super-Resolution Model 2
pretrained_model_path = "showlab/show-1-sr2"
pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
    pretrained_model_path,
    torch_dtype=torch.float16
)
pipe_sr_2.enable_model_cpu_offload()
pipe_sr_2.enable_vae_slicing()

output_dir = "./outputs"
os.makedirs(output_dir, exist_ok=True)

def infer(prompt):
    print(prompt)
    negative_prompt = "low resolution, blur"

    # Text embeds
    prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)

    # Keyframes generation (8x64x40, 2fps)
    video_frames = pipe_base(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        num_frames=8,
        height=40,
        width=64,
        num_inference_steps=75,
        guidance_scale=9.0,
        output_type="pt"
    ).frames

    # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
    # Three frames are synthesized between every pair of keyframes: 8 + 3*(8-1) = 29 frames.
    bsz, channel, num_frames, height, width = video_frames.shape
    new_num_frames = 3 * (num_frames - 1) + num_frames
    new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
                                   dtype=video_frames.dtype, device=video_frames.device)
    new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
    init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
                              device=video_frames.device)

    for i in range(num_frames - 1):
        batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
        batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
        batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
        batch_i = pipe_interp_1(
            pixel_values=batch_i,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            num_frames=batch_i.shape[2],
            height=40,
            width=64,
            num_inference_steps=50,
            guidance_scale=4.0,
            output_type="pt",
            init_noise=init_noise,
            cond_interpolation=True,
        ).frames

        new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i

    video_frames = new_video_frames

    # Super-resolution 1 (29x64x40 -> 29x256x160)
    # Upscale in overlapping 8-frame windows; each window after the first is conditioned
    # on the already-upscaled frame it shares with the previous window for consistency.
    bsz, channel, num_frames, height, width = video_frames.shape
    window_size, stride = 8, 7
    new_video_frames = torch.zeros(
        (bsz, channel, num_frames, height * 4, width * 4),
        dtype=video_frames.dtype,
        device=video_frames.device)
    for i in range(0, num_frames - window_size + 1, stride):
        batch_i = video_frames[:, :, i:i + window_size, ...]

        if i == 0:
            first_frame_cond = pipe_sr_1_image(
                image=video_frames[:, :, 0, ...],
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_embeds,
                height=height * 4,
                width=width * 4,
                num_inference_steps=50,
                guidance_scale=4.0,
                noise_level=150,
                output_type="pt"
            ).images
            first_frame_cond = first_frame_cond.unsqueeze(2)
        else:
            first_frame_cond = new_video_frames[:, :, i:i + 1, ...]

        batch_i = pipe_sr_1_cond(
            image=batch_i,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            first_frame_cond=first_frame_cond,
            height=height * 4,
            width=width * 4,
            num_inference_steps=50,
            guidance_scale=7.0,
            noise_level=250,
            output_type="pt"
        ).frames
        new_video_frames[:, :, i:i + window_size, ...] = batch_i

    video_frames = new_video_frames

    # Super-resolution 2 (29x256x160 -> 29x576x320)
    video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
    video_frames = pipe_sr_2(
        prompt,
        negative_prompt=negative_prompt,
        video=video_frames,
        strength=0.8,
        num_inference_steps=50,
    ).frames

    video_path = export_to_video(video_frames, f"{output_dir}/{prompt[:200]}.mp4")
    print(video_path)
    return video_path

css = """
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
  animation: spin 1s linear infinite;
}

@keyframes spin {
  from {
    transform: rotate(0deg);
  }
  to {
    transform: rotate(360deg);
  }
}

#share-btn-container {
  display: flex;
  padding-left: 0.5rem !important;
  padding-right: 0.5rem !important;
  background-color: #000000;
  justify-content: center;
  align-items: center;
  border-radius: 9999px !important;
  max-width: 15rem;
  height: 36px;
}

div#share-btn-container > div {
  flex-direction: row;
  background: black;
  align-items: center;
}

#share-btn-container:hover {
  background-color: #060606;
}

#share-btn {
  all: initial;
  color: #ffffff;
  font-weight: 600;
  cursor: pointer;
  font-family: 'IBM Plex Sans', sans-serif;
  margin-left: 0.5rem !important;
  padding-top: 0.5rem !important;
  padding-bottom: 0.5rem !important;
  right: 0;
}

#share-btn * {
  all: unset;
}

#share-btn-container div:nth-child(-n+2) {
  width: auto !important;
  min-height: 0px !important;
}

#share-btn-container .wrap {
  display: none !important;
}

#share-btn-container.hidden {
  display: none !important;
}

img[src*='#center'] {
  display: inline-block;
  margin: unset;
}

.footer {
  margin-bottom: 45px;
  margin-top: 10px;
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
}
.footer > p {
  font-size: .8rem;
  display: inline-block;
  padding: 0 10px;
  transform: translateY(10px);
  background: white;
}
.dark .footer {
  border-color: #303030;
}
.dark .footer > p {
  background: #0b0f19;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center;">Show-1 Text-to-Video</h1>
            <p style="text-align: center;">
              A text-to-video generation model that marries the strengths and alleviates the weaknesses of pixel-based and latent-based VDMs. <br />
            </p>

            <p style="text-align: center;">
              <a href="https://arxiv.org/abs/2309.15818" target="_blank">Paper</a> |
              <a href="https://showlab.github.io/Show-1" target="_blank">Project Page</a> |
              <a href="https://github.com/showlab/Show-1" target="_blank">Github</a>
            </p>
            """
        )

        prompt_in = gr.Textbox(label="Prompt", placeholder="A panda taking a selfie", elem_id="prompt-in")
        #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
        #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
        submit_btn = gr.Button("Submit")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

        gr.HTML("""
        <div class="footer">
            <p>
                Demo adapted from <a href="https://huggingface.co/spaces/fffiloni/zeroscope" target="_blank">zeroscope</a>
                by 🤗 <a href="https://twitter.com/fffiloni" target="_blank">Sylvain Filoni</a>
            </p>
        </div>
        """)

        submit_btn.click(fn=infer,
                         inputs=[prompt_in],
                         outputs=[video_result],
                         api_name="show-1")

demo.queue(max_size=12).launch(show_api=True)
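
Because the click handler registers api_name="show-1" and the app launches with show_api=True, the generated video can also be requested programmatically. Below is a minimal client-side sketch using gradio_client; the Space id shown is a placeholder assumption, not taken from this commit, so substitute the real Space URL or "user/space-name".

# Client-side sketch (not part of app.py).
# Assumption: "AngeT10/show-1" is a hypothetical Space id; replace it with the actual one.
from gradio_client import Client

client = Client("AngeT10/show-1")
video_path = client.predict(
    "A panda taking a selfie",   # maps to the single `prompt_in` textbox
    api_name="/show-1",          # matches api_name="show-1" on submit_btn.click
)
print(video_path)                # local path to the downloaded .mp4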