add support for ZeroGPU
Files changed:
- app.py (+12 -5)
- inference_utils.py (+34 -26)
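The hunks below disable xformers and extend the Gradio interface; the GPU allocation itself comes from Hugging Face's spaces package, which attaches a GPU to a ZeroGPU Space only while a decorated function is running. A minimal sketch of that hook, assuming send_to_model is the entry point Gradio calls (this part of app.py is not visible in the hunks shown here):

    # Sketch only: the decorator placement is an assumption, since the
    # diff below does not include this part of app.py.
    import spaces

    @spaces.GPU(duration=120)  # seconds of GPU time to reserve per call
    def send_to_model(video, prompt, neg_prompt, guidance_scale, video_length, old_qk):
        ...  # run the FLATTEN pipeline while the GPU is attached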
app.py
CHANGED
@@ -30,9 +30,9 @@ if __name__ == "__main__":
     gr.Interface(
         fn=send_to_model,
         inputs=[
-            gr.Video(value=None, label="…
-            gr.Textbox(value="", label="…
-            gr.Textbox(value="", label="…
+            gr.Video(value=None, label="source_video"),
+            gr.Textbox(value="", label="prompt"),
+            gr.Textbox(value="", label="neg_prompt"),
             gr.Slider(
                 value = 15,
                 minimum = 10,
@@ -41,8 +41,15 @@ if __name__ == "__main__":
                 label = "guidance_scale",
                 info = "The scale of the guidance field.",
             ),
-            gr.…
-            …
+            gr.Slider(
+                value = 16,
+                minimum = 8,
+                maximum = 32,
+                step = 2,
+                label = "video_length",
+                info = "The length of the video; it must be no more than 16 frames in the online demo to avoid a timeout, but you can run the model locally to process longer videos.",
+            ),
+            gr.Dropdown(value=0, choices=[0, 1], label="old_qk", info="Select 0 or 1."),
         ],
         outputs=[gr.Video(label="output", autoplay=True)],
         allow_flagging="never",
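Gradio passes the inputs list to fn positionally, so send_to_model now receives six values in this order: the source video (a temporary file path from gr.Video), prompt, neg_prompt, guidance_scale, video_length, and old_qk. The body of send_to_model is not part of this commit; a hypothetical forwarding to inference() could look like:

    # Hypothetical body (not shown in this diff): map the six positional
    # Gradio inputs onto inference()'s keyword arguments.
    def send_to_model(video, prompt, neg_prompt, guidance_scale, video_length, old_qk):
        return inference(
            prompt=prompt,
            neg_prompt=neg_prompt,
            guidance_scale=guidance_scale,
            video_length=int(video_length),  # slider values arrive as numbers
            video_path=video,
            old_qk=int(old_qk),
        )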
inference_utils.py
CHANGED
@@ -29,7 +29,8 @@ def init_pipeline(device):
 
     pipe = FlattenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, inverse_scheduler=inverse)
     pipe.enable_vae_slicing()
-    pipe.enable_xformers_memory_efficient_attention()
+    # xformers is not available in ZeroGPU?
+    # pipe.enable_xformers_memory_efficient_attention()
     pipe.to(device)
     return pipe
 
@@ -44,16 +45,16 @@ pipe = init_pipeline(device)
 
 
 def inference(
-    seed…
-    prompt…
-    neg_prompt…
-    guidance_scale = 10.0,
-    video_length…
-    video_path…
-    output_dir…
-    frame_rate…
-    fps…
-    old_qk…
+    seed: int = 66,
+    prompt: str = None,
+    neg_prompt: str = "",
+    guidance_scale: float = 10.0,
+    video_length: int = 16,
+    video_path: str = None,
+    output_dir: str = None,
+    frame_rate: int = 1,
+    fps: int = 15,
+    old_qk: int = 0,
 ):
     generator = torch.Generator(device=device)
     generator.manual_seed(seed)
@@ -73,21 +74,28 @@ def inference(
 
     for k in trajectories.keys():
         trajectories[k] = trajectories[k].to(device)
-    sample = (
-        …
+    sample = (
+        pipe(
+            prompt,
+            video_length = video_length,
+            frames = real_frames,
+            num_inference_steps = sample_steps,
+            generator = generator,
+            guidance_scale = guidance_scale,
+            negative_prompt = neg_prompt,
+            width = width,
+            height = height,
+            trajs = trajectories,
+            output_dir = "tmp/",
+            inject_step = inject_step,
+            old_qk = old_qk,
+        )
+        .videos[0]
+        .permute(1, 2, 3, 0)
+        .cpu()
+        .numpy()
+        * 255
+    ).astype(np.uint8)
     temp_video_name = f"/tmp/{prompt}_{neg_prompt}_{str(guidance_scale)}_{time.time()}.mp4".replace(" ", "-")
     video_writer = imageio.get_writer(temp_video_name, fps=fps)
     for frame in sample:
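The last hunk ends at the loop header because the remaining lines are unchanged context. Assuming the tail of inference() simply writes each frame and returns the file for the gr.Video output, it would continue roughly as:

    # Assumed continuation (unchanged code outside the diff context).
    # After the permute/cpu/numpy chain above, sample has shape
    # (frames, height, width, channels) and dtype uint8.
    for frame in sample:
        video_writer.append_data(frame)  # imageio: write one RGB frame
    video_writer.close()
    return temp_video_name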