fffiloni committed on
Commit f6be7c6 · 1 Parent(s): 454eedf

Update app.py

Files changed (1)
  1. app.py +22 -20
app.py CHANGED
@@ -60,21 +60,21 @@ def get_args():
     args = parser.parse_args()
     return args
 
-if __name__ == "__main__":
-    args = get_args()
-    os.makedirs(args.output_path, exist_ok=True)
+def infer(prompt, video_path, output_path, condition, video_length, height, width, smoother_steps, is_long_video, seed):
+    #args = get_args()
+    #os.makedirs(args.output_path, exist_ok=True)
 
     # Height and width should be a multiple of 32
-    args.height = (args.height // 32) * 32
-    args.width = (args.width // 32) * 32
+    height = (height // 32) * 32
+    width = (width // 32) * 32
 
-    if args.condition == "pose":
+    if condition == "pose":
         pretrained_model_or_path = "lllyasviel/ControlNet"
         body_model_path = hf_hub_download(pretrained_model_or_path, "annotator/ckpts/body_pose_model.pth", cache_dir="checkpoints")
         body_estimation = Body(body_model_path)
-        annotator = controlnet_parser_dict[args.condition](body_estimation)
+        annotator = controlnet_parser_dict[condition](body_estimation)
     else:
-        annotator = controlnet_parser_dict[args.condition]()
+        annotator = controlnet_parser_dict[condition]()
 
     tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
     text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
@@ -93,41 +93,43 @@ if __name__ == "__main__":
     pipe.to(device)
 
     generator = torch.Generator(device="cuda")
-    generator.manual_seed(args.seed)
+    generator.manual_seed(seed)
 
     # Step 1. Read a video
-    video = read_video(video_path=args.video_path, video_length=args.video_length, width=args.width, height=args.height)
+    video = read_video(video_path=video_path, video_length=video_length, width=width, height=height)
 
     # Save source video
     original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
-    save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True)
+    save_videos_grid(original_pixels, os.path.join(output_path, "source_video.mp4"), rescale=True)
 
 
     # Step 2. Parse a video to conditional frames
     pil_annotation = get_annotation(video, annotator)
-    if args.condition == "depth" and controlnet_aux.__version__ == '0.0.1':
+    if condition == "depth" and controlnet_aux.__version__ == '0.0.1':
         pil_annotation = [pil_annot[0] for pil_annot in pil_annotation]
 
     # Save condition video
     video_cond = [np.array(p).astype(np.uint8) for p in pil_annotation]
-    imageio.mimsave(os.path.join(args.output_path, f"{args.condition}_condition.mp4"), video_cond, fps=8)
+    imageio.mimsave(os.path.join(output_path, f"{condition}_condition.mp4"), video_cond, fps=8)
 
     # Reduce memory (optional)
     del annotator; torch.cuda.empty_cache()
 
     # Step 3. inference
 
-    if args.is_long_video:
-        window_size = int(np.sqrt(args.video_length))
-        sample = pipe.generate_long_video(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
+    if is_long_video:
+        window_size = int(np.sqrt(video_length))
+        sample = pipe.generate_long_video(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
             num_inference_steps=50, smooth_steps=args.smoother_steps, window_size=window_size,
             generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
-            width=args.width, height=args.height
+            width=width, height=height
             ).videos
     else:
-        sample = pipe(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
+        sample = pipe(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
             num_inference_steps=50, smooth_steps=args.smoother_steps,
             generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
-            width=args.width, height=args.height
+            width=width, height=height
             ).videos
-    save_videos_grid(sample, f"{args.output_path}/{args.prompt}.mp4")
+    save_videos_grid(sample, f"{output_path}/{prompt}.mp4")
+
+    return f"{output_path}/{prompt}.mp4"
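The change above converts the former command-line entry point into an infer() function that returns the path of the rendered video, which is the shape a Gradio event handler expects. Below is a minimal sketch of how that entry point could be wired into a Gradio interface; the component types, the condition choices, and all default values are assumptions for illustration only and are not taken from this commit, and the real UI defined elsewhere in app.py may differ.

import gradio as gr

# Hypothetical wiring of infer() into a Gradio UI. Component choices and
# defaults are assumptions, not part of this commit.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt"),                              # prompt
        gr.Video(label="Source video"),                          # video_path (a file path)
        gr.Textbox(value="outputs", label="Output directory"),   # output_path (assumed default; the makedirs call is commented out above)
        gr.Dropdown(["depth", "pose"], value="depth", label="Condition"),  # conditions referenced in the hunks above
        gr.Slider(8, 48, value=15, step=1, label="Video length (frames)"),
        gr.Slider(256, 768, value=512, step=32, label="Height"),
        gr.Slider(256, 768, value=512, step=32, label="Width"),
        gr.Textbox(value="19, 20", label="Smoother steps"),      # how infer() turns this into step indices is not shown in the hunk
        gr.Checkbox(value=False, label="Long video"),            # is_long_video
        gr.Number(value=42, precision=0, label="Seed"),
    ],
    outputs=gr.Video(label="Result"),
)

demo.launch()

Launching something like this from app.py would expose the same parameters that get_args() previously read from the command line.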