Update app.py
app.py CHANGED
@@ -60,21 +60,21 @@ def get_args():
     args = parser.parse_args()
     return args
 
-if __name__ == "__main__":
-    args = get_args()
-    os.makedirs(args.output_path, exist_ok=True)
+def infer(prompt, video_path, output_path, condition, video_length, height, width, smoother_steps, is_long_video, seed):
+    #args = get_args()
+    #os.makedirs(args.output_path, exist_ok=True)
 
     # Height and width should be a multiple of 32
-    args.height = (args.height // 32) * 32
-    args.width = (args.width // 32) * 32
+    height = (height // 32) * 32
+    width = (width // 32) * 32
 
-    if args.condition == "pose":
+    if condition == "pose":
         pretrained_model_or_path = "lllyasviel/ControlNet"
         body_model_path = hf_hub_download(pretrained_model_or_path, "annotator/ckpts/body_pose_model.pth", cache_dir="checkpoints")
         body_estimation = Body(body_model_path)
-        annotator = controlnet_parser_dict[args.condition](body_estimation)
+        annotator = controlnet_parser_dict[condition](body_estimation)
     else:
-        annotator = controlnet_parser_dict[args.condition]()
+        annotator = controlnet_parser_dict[condition]()
 
     tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
     text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
@@ -93,41 +93,43 @@ if __name__ == "__main__":
     pipe.to(device)
 
     generator = torch.Generator(device="cuda")
-    generator.manual_seed(args.seed)
+    generator.manual_seed(seed)
 
     # Step 1. Read a video
-    video = read_video(video_path=args.video_path, video_length=args.video_length, width=args.width, height=args.height)
+    video = read_video(video_path=video_path, video_length=video_length, width=width, height=height)
 
     # Save source video
     original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
-    save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True)
+    save_videos_grid(original_pixels, os.path.join(output_path, "source_video.mp4"), rescale=True)
 
 
     # Step 2. Parse a video to conditional frames
     pil_annotation = get_annotation(video, annotator)
-    if args.condition == "depth" and controlnet_aux.__version__ == '0.0.1':
+    if condition == "depth" and controlnet_aux.__version__ == '0.0.1':
         pil_annotation = [pil_annot[0] for pil_annot in pil_annotation]
 
     # Save condition video
     video_cond = [np.array(p).astype(np.uint8) for p in pil_annotation]
-    imageio.mimsave(os.path.join(args.output_path, f"{args.condition}_condition.mp4"), video_cond, fps=8)
+    imageio.mimsave(os.path.join(output_path, f"{condition}_condition.mp4"), video_cond, fps=8)
 
     # Reduce memory (optional)
     del annotator; torch.cuda.empty_cache()
 
     # Step 3. inference
 
-    if args.is_long_video:
-        window_size = int(np.sqrt(args.video_length))
-        sample = pipe.generate_long_video(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
-                                          num_inference_steps=50, smooth_steps=args.smoother_steps, window_size=window_size,
+    if is_long_video:
+        window_size = int(np.sqrt(video_length))
+        sample = pipe.generate_long_video(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
+                                          num_inference_steps=50, smooth_steps=smoother_steps, window_size=window_size,
                                           generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
-                                          width=args.width, height=args.height
+                                          width=width, height=height
                                           ).videos
     else:
-        sample = pipe(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
-                      num_inference_steps=50, smooth_steps=args.smoother_steps,
+        sample = pipe(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
+                      num_inference_steps=50, smooth_steps=smoother_steps,
                       generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
-                      width=args.width, height=args.height
+                      width=width, height=height
                       ).videos
-    save_videos_grid(sample, f"{args.output_path}/{args.prompt}.mp4")
+    save_videos_grid(sample, f"{output_path}/{prompt}.mp4")
+
+    return f"{output_path}/{prompt}.mp4"
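
The change refactors app.py's command-line entry point into a reusable infer() function: get_args() and the os.makedirs call are commented out, every args.* reference becomes an explicit parameter, and the function now returns the path of the generated video, which suggests it is meant to be called from a UI callback. Below is a minimal sketch of how such a function could be wired into a Gradio interface; the gr.Interface layout, the "outputs" directory, the default values, and the comma-separated smoother-steps parsing are illustrative assumptions, not part of this commit.

# Hypothetical Gradio wiring for the new infer() entry point (not from this commit).
# Assumes infer() is in scope (e.g. this code sits at the bottom of app.py) and that
# smoother steps arrive as a comma-separated string such as "19,20".
import os
import gradio as gr

OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # infer() itself no longer creates the directory

def run(prompt, video_path, condition, video_length, height, width, smoother_steps, is_long_video, seed):
    smooth = [int(s) for s in smoother_steps.split(",")]
    # infer() returns f"{output_path}/{prompt}.mp4"
    return infer(prompt, video_path, OUTPUT_DIR, condition,
                 int(video_length), int(height), int(width),
                 smooth, is_long_video, int(seed))

demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Video(label="Source video"),
        gr.Dropdown(["depth", "pose"], label="Condition"),  # any key of controlnet_parser_dict would work
        gr.Number(value=15, label="Video length"),
        gr.Number(value=512, label="Height"),
        gr.Number(value=512, label="Width"),
        gr.Textbox(value="19,20", label="Smoother steps"),
        gr.Checkbox(label="Long video"),
        gr.Number(value=42, label="Seed"),
    ],
    outputs=gr.Video(label="Result"),
)
demo.launch()

Note that infer() rounds height and width down to multiples of 32 before use, and the long-video branch derives its window size as int(np.sqrt(video_length)), e.g. 3 for a 15-frame clip.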