Update
app.py CHANGED
@@ -10,9 +10,7 @@ import numpy as np
 import spaces
 import torch
 import torchvision
-from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import snapshot_download
-from packaging import version
 from PIL import Image
 from scipy.interpolate import PchipInterpolator

@@ -39,55 +37,40 @@ snapshot_download(
 )


-
-
+model_id = "checkpoints/framer_512x320"
+device = "cuda"
+dtype = torch.float16

-
+OUTPUT_DIR = "gradio_demo/outputs"
+HEIGHT = 320
+WIDTH = 512
+MODEL_LENGTH = 14
+USE_SIFT = False

-parser.add_argument("--min_guidance_scale", type=float, default=1.0)
-parser.add_argument("--max_guidance_scale", type=float, default=3.0)
-parser.add_argument("--middle_max_guidance", type=int, default=0, choices=[0, 1])
-parser.add_argument("--with_control", type=int, default=1, choices=[0, 1])

-
-
-
-
-
-
-
-
-parser.add_argument(
-    "--model",
-    type=str,
-    default="checkpoints/framer_512x320",
-    help="Path to model.",
-)
-
-parser.add_argument("--output_dir", type=str, default="gradio_demo/outputs", help="Path to the output video.")
-
-parser.add_argument("--seed", type=int, default=42, help="random seed.")
-
-parser.add_argument("--noise_aug", type=float, default=0.02)
-
-parser.add_argument("--num_frames", type=int, default=14)
-parser.add_argument("--frame_interval", type=int, default=2)
-
-parser.add_argument("--width", type=int, default=512)
-parser.add_argument("--height", type=int, default=320)
-
-parser.add_argument(
-    "--num_workers",
-    type=int,
-    default=0,
-    help=(
-        "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
-    ),
-)
-
-args = parser.parse_args()
+unet = UNetSpatioTemporalConditionModel.from_pretrained(
+    os.path.join(model_id, "unet"),
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+    custom_resume=True,
+)
+unet = unet.to(device, dtype)

-
+controlnet = ControlNetSVDModel.from_pretrained(
+    os.path.join(model_id, "controlnet"),
+)
+controlnet = controlnet.to(device, dtype)
+
+pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
+    "checkpoints/stable-video-diffusion-img2vid-xt",
+    unet=unet,
+    controlnet=controlnet,
+    low_cpu_mem_usage=False,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    local_files_only=True,
+)
+pipe.to(device)


 def interpolate_trajectory(points, n_points):
@@ -164,7 +147,7 @@ def get_vis_image(
         vis_img = new_img.copy()
         # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))

-        if idxx >=
+        if idxx >= num_frames:
             break

         # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
@@ -363,187 +346,6 @@ def validate_and_convert_image(image, target_size=(512, 512)):
     return image


-class Drag:
-
-    @spaces.GPU
-    def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
-        self.device = device
-        self.dtype = dtype
-
-        unet = UNetSpatioTemporalConditionModel.from_pretrained(
-            os.path.join(args.model, "unet"),
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            custom_resume=True,
-        )
-        unet = unet.to(device, dtype)
-
-        controlnet = ControlNetSVDModel.from_pretrained(
-            os.path.join(args.model, "controlnet"),
-        )
-        controlnet = controlnet.to(device, dtype)
-
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = version.parse(xformers.__version__)
-            unet.enable_xformers_memory_efficient_attention()
-            # controlnet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
-        pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
-            "checkpoints/stable-video-diffusion-img2vid-xt",
-            unet=unet,
-            controlnet=controlnet,
-            low_cpu_mem_usage=False,
-            torch_dtype=torch.float16,
-            variant="fp16",
-            local_files_only=True,
-        )
-        pipe.to(device)
-
-        self.pipeline = pipe
-        # self.pipeline.enable_model_cpu_offload()
-
-        self.height = height
-        self.width = width
-        self.args = args
-        self.model_length = model_length
-        self.use_sift = use_sift
-
-    @spaces.GPU
-    def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
-        original_width, original_height = 512, 320  # TODO
-
-        # load_image
-        image = Image.open(first_frame_path).convert("RGB")
-        width, height = image.size
-        image = image.resize((self.width, self.height))
-
-        image_end = Image.open(last_frame_path).convert("RGB")
-        image_end = image_end.resize((self.width, self.height))
-
-        input_all_points = tracking_points
-
-        sift_track_update = False
-        anchor_points_flag = None
-
-        if (len(input_all_points) == 0) and self.use_sift:
-            sift_track_update = True
-            controlnet_cond_scale = 0.5
-
-            from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
-            from models_diffusers.sift_match import sift_match
-
-            output_file_sift = os.path.join(args.output_dir, "sift.png")
-
-            # (f, topk, 2), f=2 (before interpolation)
-            pred_tracks = sift_match(
-                image,
-                image_end,
-                thr=0.5,
-                topk=5,
-                method="random",
-                output_path=output_file_sift,
-            )
-
-            if pred_tracks is not None:
-                # interpolate the tracks, following draganything gradio demo
-                pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=self.model_length)
-
-                anchor_points_flag = torch.zeros((self.model_length, pred_tracks.shape[1])).to(pred_tracks.device)
-                anchor_points_flag[0] = 1
-                anchor_points_flag[-1] = 1
-
-            pred_tracks = pred_tracks.permute(1, 0, 2)  # (num_points, num_frames, 2)
-
-        else:
-
-            resized_all_points = [
-                tuple(
-                    [
-                        tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
-                        for e1 in e
-                    ]
-                )
-                for e in input_all_points
-            ]
-
-            # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
-            # in image w & h scale
-
-            for idx, splited_track in enumerate(resized_all_points):
-                if len(splited_track) == 0:
-                    warnings.warn("running without point trajectory control")
-                    continue
-
-                if len(splited_track) == 1:  # stationary point
-                    displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
-                    splited_track = tuple([splited_track[0], displacement_point])
-                # interpolate the track
-                splited_track = interpolate_trajectory(splited_track, self.model_length)
-                splited_track = splited_track[: self.model_length]
-                resized_all_points[idx] = splited_track
-
-            pred_tracks = torch.tensor(resized_all_points)  # (num_points, num_frames, 2)
-
-        vis_images = get_vis_image(
-            target_size=(self.args.height, self.args.width),
-            points=pred_tracks,
-            num_frames=self.model_length,
-        )
-
-        if len(pred_tracks.shape) != 3:
-            print("pred_tracks.shape", pred_tracks.shape)
-            with_control = False
-            controlnet_cond_scale = 0.0
-        else:
-            with_control = True
-            pred_tracks = pred_tracks.permute(1, 0, 2).to(self.device, self.dtype)  # (num_frames, num_points, 2)
-
-        point_embedding = None
-        video_frames = self.pipeline(
-            image,
-            image_end,
-            # trajectory control
-            with_control=with_control,
-            point_tracks=pred_tracks,
-            point_embedding=point_embedding,
-            with_id_feature=False,
-            controlnet_cond_scale=controlnet_cond_scale,
-            # others
-            num_frames=14,
-            width=width,
-            height=height,
-            # decode_chunk_size=8,
-            # generator=generator,
-            motion_bucket_id=motion_bucket_id,
-            fps=7,
-            num_inference_steps=30,
-            # track
-            sift_track_update=sift_track_update,
-            anchor_points_flag=anchor_points_flag,
-        ).frames[0]
-
-        vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
-        vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
-        vis_images = [Image.fromarray(img) for img in vis_images]
-
-        # video_frames = [img for sublist in video_frames for img in sublist]
-        val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
-        save_gifs_side_by_side(
-            video_frames,
-            vis_images[: self.model_length],
-            val_save_dir,
-            target_size=(self.width, self.height),
-            duration=110,
-            point_tracks=pred_tracks,
-        )
-
-        return val_save_dir
-
-
 def reset_states(first_frame_path, last_frame_path, tracking_points):
     first_frame_path = None
     last_frame_path = None
@@ -561,7 +363,7 @@ def preprocess_image(image):
     # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
     image_pil = image_pil.resize((512, 320), Image.BILINEAR)

-    first_frame_path = os.path.join(
+    first_frame_path = os.path.join(OUTPUT_DIR, f"first_frame_{str(uuid.uuid4())[:4]}.png")

     image_pil.save(first_frame_path)

@@ -578,7 +380,7 @@ def preprocess_image_end(image_end):
     # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
     image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)

-    last_frame_path = os.path.join(
+    last_frame_path = os.path.join(OUTPUT_DIR, f"last_frame_{str(uuid.uuid4())[:4]}.png")

     image_end_pil.save(last_frame_path)

@@ -692,7 +494,7 @@ def add_tracking_points(
     transparent_layer = 0
     for idx, track in enumerate(tracking_points):
         # mask = cv2.imread(
-        #     os.path.join(
+        #     os.path.join(OUTPUT_DIR, f"mask_{idx+1}.jpg")
         # )
         mask = np.zeros((320, 512, 3))
         color = color_list[idx + 1]
@@ -737,10 +539,136 @@ def add_tracking_points(
     return tracking_points, trajectory_map, trajectory_map_end


+@spaces.GPU
+def run(first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
+    original_width, original_height = 512, 320  # TODO
+
+    # load_image
+    image = Image.open(first_frame_path).convert("RGB")
+    width, height = image.size
+    image = image.resize((WIDTH, HEIGHT))
+
+    image_end = Image.open(last_frame_path).convert("RGB")
+    image_end = image_end.resize((WIDTH, HEIGHT))
+
+    input_all_points = tracking_points
+
+    sift_track_update = False
+    anchor_points_flag = None
+
+    if (len(input_all_points) == 0) and USE_SIFT:
+        sift_track_update = True
+        controlnet_cond_scale = 0.5
+
+        from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
+        from models_diffusers.sift_match import sift_match
+
+        output_file_sift = os.path.join(OUTPUT_DIR, "sift.png")
+
+        # (f, topk, 2), f=2 (before interpolation)
+        pred_tracks = sift_match(
+            image,
+            image_end,
+            thr=0.5,
+            topk=5,
+            method="random",
+            output_path=output_file_sift,
+        )
+
+        if pred_tracks is not None:
+            # interpolate the tracks, following draganything gradio demo
+            pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH)
+
+            anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device)
+            anchor_points_flag[0] = 1
+            anchor_points_flag[-1] = 1
+
+        pred_tracks = pred_tracks.permute(1, 0, 2)  # (num_points, num_frames, 2)
+
+    else:
+
+        resized_all_points = [
+            tuple([tuple([int(e1[0] * WIDTH / original_width), int(e1[1] * HEIGHT / original_height)]) for e1 in e])
+            for e in input_all_points
+        ]
+
+        # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
+        # in image w & h scale
+
+        for idx, splited_track in enumerate(resized_all_points):
+            if len(splited_track) == 0:
+                warnings.warn("running without point trajectory control")
+                continue
+
+            if len(splited_track) == 1:  # stationary point
+                displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
+                splited_track = tuple([splited_track[0], displacement_point])
+            # interpolate the track
+            splited_track = interpolate_trajectory(splited_track, MODEL_LENGTH)
+            splited_track = splited_track[:MODEL_LENGTH]
+            resized_all_points[idx] = splited_track
+
+        pred_tracks = torch.tensor(resized_all_points)  # (num_points, num_frames, 2)
+
+    vis_images = get_vis_image(
+        target_size=(HEIGHT, WIDTH),
+        points=pred_tracks,
+        num_frames=MODEL_LENGTH,
+    )
+
+    if len(pred_tracks.shape) != 3:
+        print("pred_tracks.shape", pred_tracks.shape)
+        with_control = False
+        controlnet_cond_scale = 0.0
+    else:
+        with_control = True
+        pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype)  # (num_frames, num_points, 2)
+
+    point_embedding = None
+    video_frames = pipe(
+        image,
+        image_end,
+        # trajectory control
+        with_control=with_control,
+        point_tracks=pred_tracks,
+        point_embedding=point_embedding,
+        with_id_feature=False,
+        controlnet_cond_scale=controlnet_cond_scale,
+        # others
+        num_frames=14,
+        width=width,
+        height=height,
+        # decode_chunk_size=8,
+        # generator=generator,
+        motion_bucket_id=motion_bucket_id,
+        fps=7,
+        num_inference_steps=30,
+        # track
+        sift_track_update=sift_track_update,
+        anchor_points_flag=anchor_points_flag,
+    ).frames[0]
+
+    vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
+    vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
+    vis_images = [Image.fromarray(img) for img in vis_images]
+
+    # video_frames = [img for sublist in video_frames for img in sublist]
+    val_save_dir = os.path.join(OUTPUT_DIR, "vis_gif.gif")
+    save_gifs_side_by_side(
+        video_frames,
+        vis_images[:MODEL_LENGTH],
+        val_save_dir,
+        target_size=(WIDTH, HEIGHT),
+        duration=110,
+        point_tracks=pred_tracks,
+    )
+
+    return val_save_dir
+
+
 if __name__ == "__main__":

-
-    ensure_dirname(args.output_dir)
+    ensure_dirname(OUTPUT_DIR)

     color_list = []
     for i in range(20):
@@ -771,8 +699,6 @@ if __name__ == "__main__":
             3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
         )

-        # device, args, height, width, model_length
-        Framer = Drag("cuda", args, 320, 512, 14)
         first_frame_path = gr.State()
         last_frame_path = gr.State()
         tracking_points = gr.State([])
@@ -898,7 +824,7 @@ if __name__ == "__main__":
         )

         run_button.click(
-            fn=
+            fn=run,
             inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
             outputs=output_video,
         )
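For context, the diff replaces the argparse-driven Drag class with module-level constants, a module-level pipeline, and a standalone run() function that the Gradio button calls directly. The sketch below is a minimal, stubbed illustration of that wiring pattern, not the app itself: the model loading and pipeline call are replaced by a stub, and the gr.Blocks layout, slider ranges and defaults, and the demo.launch() call are assumptions; only the component names visible in the diff (run_button, output_video, the gr.State holders, controlnet_cond_scale, motion_bucket_id) are taken from it.

import os

import gradio as gr
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator used in the diff

# Module-level configuration, as introduced by the diff.
OUTPUT_DIR = "gradio_demo/outputs"
HEIGHT, WIDTH, MODEL_LENGTH = 320, 512, 14


@spaces.GPU
def run(first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
    # Stub: the real function resizes both frames to (WIDTH, HEIGHT), builds point
    # tracks, runs the StableVideoDiffusionInterpControlPipeline, and writes a GIF.
    return os.path.join(OUTPUT_DIR, "vis_gif.gif")


with gr.Blocks() as demo:  # assumed layout; the diff does not show the UI definition
    first_frame_path = gr.State()
    last_frame_path = gr.State()
    tracking_points = gr.State([])
    controlnet_cond_scale = gr.Slider(0.0, 1.0, value=1.0, label="controlnet_cond_scale")  # assumed range
    motion_bucket_id = gr.Slider(1, 180, value=100, label="motion_bucket_id")  # assumed range
    output_video = gr.Video()
    run_button = gr.Button("Run")

    # The last hunk switches fn from the old Drag instance's bound method to the module-level run.
    run_button.click(
        fn=run,
        inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
        outputs=output_video,
    )

if __name__ == "__main__":
    demo.launch()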