sky24h committed
Commit fae070b · 1 Parent(s): 299fe30

add support for ZeroGPU

Files changed (2):
  1. app.py +12 -5
  2. inference_utils.py +34 -26
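For context, a ZeroGPU Space allocates a GPU only for the duration of a call, which is requested by decorating the GPU-bound function with spaces.GPU from Hugging Face's spaces package. The sketch below shows that usual pattern for a send_to_model wrapper like the one passed to gr.Interface in app.py; the wrapper body and the decorator change are not visible in the hunks shown here, so treat this as an illustrative assumption rather than the commit's exact code.

# Sketch only: the usual ZeroGPU pattern, not the exact code of this commit.
import spaces
from inference_utils import inference  # assumed import; inference is defined in the second file below

@spaces.GPU(duration=120)  # request a GPU slice for up to ~120 s per call
def send_to_model(source_video, prompt, neg_prompt, guidance_scale, video_length, old_qk):
    # Forward the Gradio inputs to the pipeline wrapper.
    return inference(
        prompt=prompt,
        neg_prompt=neg_prompt,
        guidance_scale=guidance_scale,
        video_length=int(video_length),
        video_path=source_video,
        old_qk=int(old_qk),
    )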
app.py CHANGED
@@ -30,9 +30,9 @@ if __name__ == "__main__":
     gr.Interface(
         fn=send_to_model,
         inputs=[
-            gr.Video(value=None, label="Source Image"),
-            gr.Textbox(value="", label="Prompt"),
-            gr.Textbox(value="", label="Negative Prompt"),
+            gr.Video(value=None, label="source_video"),
+            gr.Textbox(value="", label="prompt"),
+            gr.Textbox(value="", label="neg_prompt"),
             gr.Slider(
                 value = 15,
                 minimum = 10,
@@ -41,8 +41,15 @@ if __name__ == "__main__":
                 label = "guidance_scale",
                 info = "The scale of the guidance field.",
             ),
-            gr.Textbox(value=16, label="Video Length", info="The length of the video, must be less than 16 frames in the online demo to avoid timeout. However, you can run the model locally to process longer videos."),
-            gr.Dropdown(value=0, choices=[0, 1], label="Choose Option", info="Select 0 or 1."),
+            gr.Slider(
+                value = 16,
+                minimum = 8,
+                maximum = 32,
+                step = 2,
+                label = "video_length",
+                info = "The length of the video, must be less than 16 frames in the online demo to avoid timeout. However, you can run the model locally to process longer videos.",
+            ),
+            gr.Dropdown(value=0, choices=[0, 1], label="old_qk", info="Select 0 or 1."),
         ],
         outputs=[gr.Video(label="output", autoplay=True)],
         allow_flagging="never",
inference_utils.py CHANGED
@@ -29,7 +29,8 @@ def init_pipeline(device):

     pipe = FlattenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, inverse_scheduler=inverse)
     pipe.enable_vae_slicing()
-    pipe.enable_xformers_memory_efficient_attention()
+    # xformers is not available in ZeroGPU?
+    # pipe.enable_xformers_memory_efficient_attention()
     pipe.to(device)
     return pipe

@@ -44,16 +45,16 @@ pipe = init_pipeline(device)


 def inference(
-    seed = 66,
-    prompt = None,
-    neg_prompt = "",
-    guidance_scale = 10.0,
-    video_length = 16,
-    video_path = None,
-    output_dir = None,
-    frame_rate = 1,
-    fps = 15,
-    old_qk = 0,
+    seed          : int   = 66,
+    prompt        : str   = None,
+    neg_prompt    : str   = "",
+    guidance_scale: float = 10.0,
+    video_length  : int   = 16,
+    video_path    : str   = None,
+    output_dir    : str   = None,
+    frame_rate    : int   = 1,
+    fps           : int   = 15,
+    old_qk        : int   = 0,
 ):
     generator = torch.Generator(device=device)
     generator.manual_seed(seed)
@@ -73,21 +74,28 @@ def inference(

     for k in trajectories.keys():
         trajectories[k] = trajectories[k].to(device)
-    sample = (pipe(
-        prompt,
-        video_length = video_length,
-        frames = real_frames,
-        num_inference_steps = sample_steps,
-        generator = generator,
-        guidance_scale = guidance_scale,
-        negative_prompt = neg_prompt,
-        width = width,
-        height = height,
-        trajs = trajectories,
-        output_dir = "tmp/",
-        inject_step = inject_step,
-        old_qk = old_qk,
-    ).videos[0].permute(1, 2, 3, 0).cpu().numpy() * 255).astype(np.uint8)
+    sample = (
+        pipe(
+            prompt,
+            video_length = video_length,
+            frames = real_frames,
+            num_inference_steps = sample_steps,
+            generator = generator,
+            guidance_scale = guidance_scale,
+            negative_prompt = neg_prompt,
+            width = width,
+            height = height,
+            trajs = trajectories,
+            output_dir = "tmp/",
+            inject_step = inject_step,
+            old_qk = old_qk,
+        )
+        .videos[0]
+        .permute(1, 2, 3, 0)
+        .cpu()
+        .numpy()
+        * 255
+    ).astype(np.uint8)
     temp_video_name = f"/tmp/{prompt}_{neg_prompt}_{str(guidance_scale)}_{time.time()}.mp4".replace(" ", "-")
     video_writer = imageio.get_writer(temp_video_name, fps=fps)
     for frame in sample:
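On the xformers comment in init_pipeline above: instead of commenting the call out, a common alternative is to guard it so the same file runs both locally (where xformers may be installed) and on ZeroGPU. A minimal sketch, assuming a diffusers-style pipe object; this guard is not part of the commit.

# Hypothetical guard, not part of this commit: enable memory-efficient
# attention only when xformers can actually be imported.
try:
    import xformers  # noqa: F401
    pipe.enable_xformers_memory_efficient_attention()
except ImportError:
    pass  # fall back to the default attention implementation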