Spaces:
Running
Running
# ************************************************************************* | |
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- | |
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- | |
# ytedance Inc.. | |
# ************************************************************************* | |
import os | |
import argparse | |
import numpy as np | |
# torch | |
import torch | |
from ema_pytorch import EMA | |
from einops import rearrange | |
import cv2 | |
# utils | |
from utils.utils import set_seed, count_param, print_peak_memory | |
# model | |
import imageio | |
from model_lib.ControlNet.cldm.model import create_model | |
import copy | |
import glob | |
import imageio | |
from skimage.transform import resize | |
from skimage import img_as_ubyte | |
import face_alignment | |
import sys | |
from decord import VideoReader | |
from decord import cpu, gpu | |
# Major version of the installed PyTorch as a string (e.g. "2"); main()
# compares it against "2" before calling torch.compile.
TORCH_VERSION = torch.__version__.split(".")[0]
# dtype passed to autocast when mixed precision is enabled (see visualize_mm).
FP16_DTYPE = torch.float16
# Printed once at import time.
print(f"TORCH_VERSION={TORCH_VERSION} FP16_DTYPE={FP16_DTYPE}")
def extract_local_feature_from_single_img(img, fa, remove_local=False, real_tocrop=None, target_res = 512):
    """Build the eye/mouth "local" image for a single frame.

    Landmarks are detected on a 256x256 copy of ``img``. Three square windows
    (two eyes and the mouth, each ``target_res // 4`` pixels wide) are then
    either kept while everything else is zeroed, or zeroed while everything
    else is kept (``remove_local=True``).

    Args:
        img: 3-D tensor laid out channels-first (it is permuted to
            channels-last for detection). Presumably float values in [0, 1]
            at ``target_res`` resolution -- TODO confirm against callers.
        fa: landmark detector exposing ``get_landmarks_from_image``.
        remove_local: blank the windows instead of keeping only them.
        real_tocrop: optional channels-first tensor to take the pixels from
            instead of ``img``; when given, the result is NOT rescaled.
        target_res: resolution of the image the windows are cut from.

    Returns:
        A channels-first tensor on ``img.device``. When ``real_tocrop`` is
        None it is rescaled with ``x * 2 - 1``. If no face is found an
        all-black image of the same shape is returned (a single tensor, so
        callers can stack results from every frame).
    """
    device = img.device
    pred = img.permute([1, 2, 0]).detach().cpu().numpy()
    pred_lmks = img_as_ubyte(resize(pred, (256, 256)))

    try:
        detections = fa.get_landmarks_from_image(pred_lmks, return_landmark_score=False)
    except Exception:
        # Best effort: a detector failure is treated like "no face found".
        detections = None

    if detections is None or len(detections) == 0:
        print ('undetected faces!!')
        blank = torch.zeros_like(img)
        if real_tocrop is None:
            # Same rescaling as the success path, i.e. an all -1 image.
            return blank * 2 - 1.
        return blank
    lmks = detections[0]

    halfedge = 32

    def _center(points):
        # Mean landmark position, kept `halfedge` away from the 256px border,
        # then mapped to target_res pixel coordinates.
        clipped = np.clip(np.round(np.mean(points, axis=0)), halfedge, 255 - halfedge)
        return (clipped * (target_res / 256)).astype(np.int32)

    # NOTE(review): these ranges start one index after the usual 68-point
    # group starts (42 / 36 / 48). Kept as-is because trained weights may
    # depend on the exact crop positions -- confirm before changing.
    left_eye_center = _center(lmks[43:48])
    right_eye_center = _center(lmks[37:42])
    mouth_center = _center(lmks[49:68])

    if real_tocrop is not None:
        pred = real_tocrop.permute([1, 2, 0]).detach().cpu().numpy()

    half_size = target_res // 8  # 64 at the default resolution
    windows = [
        (slice(cy - half_size, cy + half_size), slice(cx - half_size, cx + half_size))
        for cx, cy in (left_eye_center, right_eye_center, mouth_center)
    ]

    if remove_local:
        # Copy first: for CPU tensors the numpy array shares memory with the
        # caller's tensor, and zeroing in place would corrupt the input.
        local_viz = pred.copy()
        for rows, cols in windows:
            local_viz[rows, cols] = 0
    else:
        local_viz = np.zeros_like(pred)
        for rows, cols in windows:
            local_viz[rows, cols] = pred[rows, cols]

    local_viz = torch.from_numpy(local_viz).to(device)
    local_viz = local_viz.permute([2, 0, 1])
    if real_tocrop is None:
        local_viz = local_viz * 2 - 1.
    return local_viz
def find_best_frame_byheadpose_fa(source_image, driving_video, fa):
    """Return the index of the driving frame whose landmarks best match the source.

    Every image is resized to 256x256 before detection and the match score is
    the sum of absolute landmark differences (lower is better).

    Args:
        source_image: source image array.
        driving_video: sequence of frame arrays.
        fa: landmark detector exposing ``get_landmarks_from_image``.

    Returns:
        Index into ``driving_video``. 0 is returned when no face is found in
        the source image or in any driving frame.
    """
    def _landmarks(image):
        # Landmarks of the first detected face, or None when nothing is found.
        small = img_as_ubyte(resize(image, (256, 256)))
        try:
            faces = fa.get_landmarks_from_image(small, return_landmark_score=False)
        except Exception:
            return None
        if faces is None or len(faces) == 0:
            return None
        return np.asarray(faces[0])

    src_pose_array = _landmarks(source_image)
    if src_pose_array is None or len(src_pose_array) == 0:
        # Without source landmarks there is nothing to match against.
        print ('undetected faces in the source image!!')
        return 0

    min_diff = float("inf")
    best_frame = 0
    for i, frame in enumerate(driving_video):
        drv_pose_array = _landmarks(frame)
        if drv_pose_array is None:
            # Skip the frame instead of scoring it with all-zero landmarks,
            # which could otherwise be picked as the "best" match.
            print ('undetected faces in the %d-th driving image!!'%i)
            continue
        diff = np.sum(np.abs(src_pose_array - drv_pose_array))
        if diff < min_diff:
            best_frame = i
            min_diff = diff
    return best_frame
def adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame=-1):
    """Resize the driving frames and align their face box to the source image.

    The landmark bounding box of one reference driving frame is matched to the
    landmark bounding box of the source image with an affine transform, and
    that transform is applied to every driving frame at both resolutions.

    Args:
        source_image: source image array (only the first 3 channels are used).
        driving_video: sequence of frame arrays.
        fa: landmark detector exposing ``get_landmarks_from_image``.
        nm_res: side length of the low-resolution output frames.
        nmd_res: side length of the high-resolution output frames.
        best_frame: index of the reference driving frame. Any other negative
            value searches for the best match automatically; ``-2`` skips
            alignment altogether.

    Returns:
        ``(frames_ld, frames_hd)``: two lists of frames. They are only resized
        (not warped) when ``best_frame == -2`` or when no face is detected in
        the source or reference frame.

    Raises:
        ValueError: ``best_frame`` is not a valid index into ``driving_video``.
    """
    def _resized_only():
        return ([resize(frame, (nm_res, nm_res)) for frame in driving_video],
                [resize(frame, (nmd_res, nmd_res)) for frame in driving_video])

    if best_frame == -2:
        return _resized_only()

    # best_frame indexes the driving video, so that is what bounds it.
    if best_frame >= len(driving_video):
        raise ValueError(
            f"please specify one frame in driving video of which the pose match best with the pose of source image"
        )

    src = img_as_ubyte(resize(source_image[..., :3], (256, 256)))
    if best_frame < 0:
        best_frame = find_best_frame_byheadpose_fa(src, driving_video, fa)
    print ('Best Frame: %d' % best_frame)

    driving = img_as_ubyte(resize(driving_video[best_frame], (256, 256)))
    src_lmks = fa.get_landmarks_from_image(src, return_landmark_score=False)
    drv_lmks = fa.get_landmarks_from_image(driving, return_landmark_score=False)
    if src_lmks is None or drv_lmks is None or len(src_lmks) == 0 or len(drv_lmks) == 0:
        return _resized_only()

    def _box_corners(lmks):
        # Top-left, top-right and bottom-left corners of the landmark box
        # (three points fully determine an affine transform).
        center = np.mean(lmks, axis=0)
        half = (np.max(lmks, axis=0) - np.min(lmks, axis=0)) * 0.5
        x0, y0 = center - half
        x1, y1 = center + half
        return np.array([[x0, y0], [x1, y0], [x0, y1]], dtype=np.float32)

    src_point = _box_corners(src_lmks[0])
    dst_point = _box_corners(drv_lmks[0])

    # Landmarks live in 256px space; scale the control points to each output
    # resolution (1x and 2x for the usual 256 / 512 pair).
    scale_ld = nm_res / 256
    scale_hd = nmd_res / 256
    warp_ld = cv2.getAffineTransform((dst_point * scale_ld).astype(np.float32),
                                     (src_point * scale_ld).astype(np.float32))
    warp_hd = cv2.getAffineTransform((dst_point * scale_hd).astype(np.float32),
                                     (src_point * scale_hd).astype(np.float32))

    adjusted_driving_video = []
    adjusted_driving_video_hd = []
    for frame in driving_video:
        frame_ld = resize(frame, (nm_res, nm_res))
        frame_hd = resize(frame, (nmd_res, nmd_res))
        adjusted_driving_video.append(cv2.warpAffine(frame_ld, warp_ld, (nm_res, nm_res)))
        adjusted_driving_video_hd.append(cv2.warpAffine(frame_hd, warp_hd, (nmd_res, nmd_res)))
    return adjusted_driving_video, adjusted_driving_video_hd
def x_portrait_data_prep(source_image_path, driving_video_path, device, best_frame_id=0, start_idx = 0, num_frames=0, skip=1, output_local=False, more_source_image_pattern="", target_resolution = 512):
    """Load the source image and driving frames and pack them into a batch dict.

    Args:
        source_image_path: path of the source (identity) image.
        driving_video_path: path of the driving input. Paths containing
            ``.mp4`` are read as videos; anything else is read as one image.
        device: device the returned tensors are moved to.
        best_frame_id: forwarded to ``adjust_driving_video_to_src_image``.
        start_idx: first driving frame to use.
        num_frames: number of frames to keep (0 keeps everything to the end).
        skip: stride between kept frames.
        output_local: also compute the eye/mouth local image per frame.
        more_source_image_pattern: optional glob of extra source images.
        target_resolution: side length of the high-resolution tensors.

    Returns:
        dict with:
          * ``'fps'``: frame rate of the driving video (1 for a still image);
          * ``'sources'``: (F, 1, C, H, W) source image repeated per frame;
          * ``'conditions'``: (F, 1, C, H, W) driving frames;
          * ``'more_sources'``: (F, N, C, H, W), only when extra sources exist;
          * ``'local'``: stacked local images, only when ``output_local``.
        Image tensors are rescaled with ``x * 2 - 1``.

    Raises:
        ValueError: the frame selection is empty.
    """
    source_image = imageio.imread(source_image_path)
    if '.mp4' in driving_video_path:
        reader = imageio.get_reader(driving_video_path)
        try:
            fps = reader.get_meta_data()['fps']
            driving_video = []
            try:
                for im in reader:
                    driving_video.append(im)
            except RuntimeError:
                # Deliberate best effort: keep the frames decoded so far.
                pass
        finally:
            # Release the reader even if reading the metadata/frames fails.
            reader.close()
    else:
        driving_video = [imageio.imread(driving_video_path)[...,:3]]
        fps = 1

    nmd_res = target_resolution
    nm_res = 256
    source_image_hd = resize(source_image, (nmd_res, nmd_res))[..., :3]

    more_sources_hd = None
    if more_source_image_pattern:
        # Sorted so the order of the extra references is reproducible.
        more_source_paths = sorted(glob.glob(more_source_image_pattern))
        if more_source_paths:
            extra = []
            for more_source_path in more_source_paths:
                more_source_image = imageio.imread(more_source_path)
                more_source_image_hd = resize(more_source_image, (nmd_res, nmd_res))[..., :3]
                more_source_hd = torch.tensor(more_source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
                extra.append(more_source_hd.to(device))
            more_sources_hd = torch.stack(extra, dim = 1)
        else:
            # torch.stack([]) would raise; fall back to "no extra sources".
            print ('no additional source image matches {}'.format(more_source_image_pattern))

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device='cuda')
    driving_video, driving_video_hd = adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame_id)

    if num_frames == 0:
        end_idx = len(driving_video)
    else:
        num_frames = min(len(driving_video), num_frames)
        end_idx = start_idx + num_frames * skip
    driving_video = driving_video[start_idx:end_idx][::skip]
    driving_video_hd = driving_video_hd[start_idx:end_idx][::skip]
    num_frames = len(driving_video)
    if num_frames == 0:
        raise ValueError(
            f"no driving frames selected from {driving_video_path} (start_idx={start_idx}, skip={skip})"
        )

    with torch.no_grad():
        real_source_hd = torch.tensor(source_image_hd[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        real_source_hd = real_source_hd.to(device)
        driving_hd = torch.tensor(np.array(driving_video_hd).astype(np.float32)).permute(0, 3, 1, 2).to(device)

        local_features = []
        raw_drivings = []
        for frame_idx in range(num_frames):
            raw_drivings.append(driving_hd[frame_idx:frame_idx+1] * 2 - 1.)
            if output_local:
                local_feature_img = extract_local_feature_from_single_img(driving_hd[frame_idx], fa,target_res=nmd_res)
                if isinstance(local_feature_img, tuple):
                    # The extractor's "no face" path may return (image, box);
                    # only the image can be stacked below.
                    local_feature_img = local_feature_img[0]
                local_features.append(local_feature_img)

    batch_data = {}
    batch_data['fps'] = fps
    real_source_hd = real_source_hd * 2 - 1
    batch_data['sources'] = real_source_hd[:, None, :, :, :].repeat([num_frames, 1, 1, 1, 1])
    if more_sources_hd is not None:
        more_sources_hd = more_sources_hd * 2 - 1
        batch_data['more_sources'] = more_sources_hd.repeat([num_frames, 1, 1, 1, 1])
    batch_data['conditions'] = torch.stack(raw_drivings, dim = 0)
    if output_local:
        batch_data['local'] = torch.stack(local_features, dim = 0)
    return batch_data
# You can now use the modified state_dict without the deleted keys | |
def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
    """Load the weights stored at ``ckpt_path`` into ``model``.

    The checkpoint may either be the state dict itself or a dict holding it
    under the ``'state_dict'`` key.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict=...)``.
        ckpt_path: checkpoint file to read with ``torch.load``.
        reinit_hint_block: drop every entry whose name starts with
            ``control_model.input_hint_block`` before loading.
        strict: forwarded to ``model.load_state_dict``.
        map_location: forwarded to ``torch.load``.
    """
    print(f"Loading model state dict from {ckpt_path} ...")
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    weights = checkpoint.get('state_dict', checkpoint)

    if reinit_hint_block:
        print("Ignoring hint block parameters from checkpoint!")
        # Collect the names first so the dict is not resized while iterating.
        hint_keys = [key for key in weights if key.startswith("control_model.input_hint_block")]
        for key in hint_keys:
            del weights[key]

    model.load_state_dict(weights, strict=strict)
    # Drop the local references so the loaded tensors can be freed.
    del checkpoint, weights
def get_cond_control(args, batch_data, control_type, device, start, end, model=None, batch_size=None, train=True, key=0):
    """Assemble the conditioning inputs for frames ``start:end``.

    Only ``control_type == "appearance_pose_local_mm"`` is supported.

    Args:
        args, batch_size, train: unused here.
        batch_data: dict with ``'sources'``, ``'conditions'`` and ``'local'``
            tensors, plus optionally ``'more_sources'``.
        control_type: name of the conditioning scheme.
        device: device the encoded latents are moved to.
        start, end: frame range to slice from ``batch_data``.
        model: provides ``encode_first_stage`` / ``get_first_stage_encoding``.
        key: index along dim 1 of ``batch_data['sources']``.

    Returns:
        ``[source_latents, conditions, local, more_latents]`` where
        ``more_latents`` holds one single-element list per extra source
        (empty when ``'more_sources'`` is absent).

    Raises:
        NotImplementedError: for any other ``control_type``.
    """
    control_type = copy.deepcopy(control_type)
    if control_type != "appearance_pose_local_mm":
        raise NotImplementedError(f"cond_type={control_type} not supported!")

    chunk = 16  # frames encoded per first-stage call
    source_frames = batch_data['sources'][start:end, key].cuda()
    c_cat_list = batch_data['conditions'][start:end].cuda()

    latents = [
        model.get_first_stage_encoding(model.encode_first_stage(source_frames[offset:offset + chunk]))
        for offset in range(0, end - start, chunk)
    ]
    cond_img_cat = torch.concat(latents, dim=0)

    p_local = batch_data['local'][start:end].cuda()
    print ('Total frames:{}'.format(cond_img_cat.shape))

    more_cond_imgs = []
    if 'more_sources' in batch_data:
        extra_sources = batch_data['more_sources']
        for idx in range(extra_sources.shape[1]):
            encoded = model.get_first_stage_encoding(model.encode_first_stage(extra_sources[start:end, idx]))
            more_cond_imgs.append([encoded.to(device)])

    return [cond_img_cat.to(device), c_cat_list, p_local, more_cond_imgs]
def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir, num_mix=4, preset_output_name=''):
    """Sample ``nSample`` frames with ``infer_model`` and save them as an mp4.

    Args:
        args: namespace providing ``uc_scale``, ``control_type``,
            ``control_mode``, ``initial_facevid2vid_results``, ``device``,
            ``wonoise``, ``use_fp16``, ``num_drivings``, ``ddim_steps``, ``eta``.
        name: prefix of the generated file name.
        batch_data: dict with ``'sources'``, ``'conditions'``, ``'local'``,
            ``'fps'``, ``'video_name'`` and ``'source_name'``.
        infer_model: the diffusion model wrapper used for sampling.
        nSample: number of frames to generate.
        local_image_dir: output directory (created if missing).
        num_mix: forwarded to ``sample_log`` as ``num_overlap``.
        preset_output_name: optional file name overriding the generated one;
            everything from its first ``.`` on is replaced with ``.mp4``.

    The video is written to ``local_image_dir``; nothing is returned.
    """
    driving_video_name = os.path.basename(batch_data['video_name']).split('.')[0]
    source_name = os.path.basename(batch_data['source_name']).split('.')[0]
    # makedirs also handles nested output paths (os.mkdir would fail on them).
    os.makedirs(local_image_dir, exist_ok=True)

    uc_scale = args.uc_scale
    if preset_output_name:
        preset_output_name = preset_output_name.split('.')[0]+'.mp4'
        output_path = f"{local_image_dir}/{preset_output_name}"
    else:
        output_path = f"{local_image_dir}/{name}_{args.control_type}_uc{uc_scale}_{source_name}_by_{driving_video_name}_mix{num_mix}.mp4"

    infer_model.eval()
    gene_img_list = []
    _, _, ch, h, w = batch_data['sources'].shape
    vae_bs = 16  # frames per first-stage encode/decode call

    # Image sequence whose noised latents initialise the sampler: either a
    # previously rendered video or the (repeated) source image.
    if args.initial_facevid2vid_results:
        facevid2vid_results = VideoReader(args.initial_facevid2vid_results, ctx=cpu(0))
        facevid2vid = []
        # Only the first nSample frames are used, so only those are decoded.
        for frame_id in range(min(len(facevid2vid_results), nSample)):
            # Resize to the source resolution (cv2 expects (width, height)).
            frame = cv2.resize(facevid2vid_results[frame_id].asnumpy(), (w, h)) / 255
            facevid2vid.append(torch.from_numpy(frame * 2 - 1).permute(2,0,1))
        cond = torch.stack(facevid2vid)[:nSample].float().to(args.device)
    else:
        cond = batch_data['sources'][:nSample].reshape([-1, ch, h, w])

    pre_noise = torch.cat(
        [infer_model.get_first_stage_encoding(infer_model.encode_first_stage(cond[i:i+vae_bs]))
         for i in range(0, nSample, vae_bs)],
        dim=0,
    )
    pre_noise = infer_model.q_sample(x_start = pre_noise, t = torch.tensor([999]).to(pre_noise.device))

    text = ["" for _ in range(nSample)]
    all_c_cat = get_cond_control(args, batch_data, args.control_type, args.device, start=0, end=nSample, model=infer_model, train=False)
    cond_img_cat = [all_c_cat[0]]
    pose_cond_list = [rearrange(all_c_cat[1], "b f c h w -> (b f) c h w")]
    local_pose_cond_list = [all_c_cat[2]]
    has_more_sources = len(all_c_cat) > 3 and len(all_c_cat[3]) > 0

    c_cross = infer_model.get_learned_conditioning(text)[:nSample]
    uc_cross = infer_model.get_unconditional_conditioning(nSample)

    # Conditional inputs.
    c = {"c_crossattn": [c_cross], "image_control": cond_img_cat}
    if "appearance_pose" in args.control_type:
        c['c_concat'] = pose_cond_list
    if "appearance_pose_local" in args.control_type:
        c["local_c_concat"] = local_pose_cond_list
    if has_more_sources:
        c['more_image_control'] = all_c_cat[3]

    # Unconditional inputs: motion controls are replaced by zeros.
    if args.control_mode == "controlnet_important":
        uc = {"c_crossattn": [uc_cross]}
    else:
        uc = {"c_crossattn": [uc_cross], "image_control":cond_img_cat}
    if "appearance_pose" in args.control_type:
        uc['c_concat'] = [torch.zeros_like(pose_cond_list[0])]
    if "appearance_pose_local" in args.control_type:
        uc["local_c_concat"] = [torch.zeros_like(local_pose_cond_list[0])]
    if has_more_sources:
        uc['more_image_control'] = all_c_cat[3]

    c['wonoise'] = uc['wonoise'] = bool(args.wonoise)

    noise = pre_noise.to(c_cross.device)
    with torch.cuda.amp.autocast(enabled=args.use_fp16, dtype=FP16_DTYPE):
        infer_model.to(args.device)
        infer_model.eval()
        gene_img, _ = infer_model.sample_log(cond=c,
                                             batch_size=args.num_drivings, ddim=True,
                                             ddim_steps=args.ddim_steps, eta=args.eta,
                                             unconditional_guidance_scale=uc_scale,
                                             unconditional_conditioning=uc,
                                             inpaint=None,
                                             x_T=noise,
                                             num_overlap=num_mix,
                                             )
        for i in range(0, nSample, vae_bs):
            gene_img_part = infer_model.decode_first_stage( gene_img[i:i+vae_bs] )
            gene_img_list.append(gene_img_part.float().clamp(-1, 1).cpu())

    # Shape of the decoded frames (separate names: `c` is the cond dict above).
    _, out_ch, out_h, out_w = gene_img_list[0].shape
    cond_image = batch_data["conditions"].reshape([-1, out_ch, out_h, out_w])[:nSample].cpu()
    l_cond_image = batch_data["local"].reshape([-1, out_ch, out_h, out_w])[:nSample].cpu()
    orig_image = batch_data["sources"][:nSample, 0].cpu()

    # Side-by-side grid per frame: generated | driving | local | source.
    output_img = torch.cat(gene_img_list + [cond_image]+[l_cond_image]+[orig_image]).float().clamp(-1,1).add(1).mul(0.5)
    num_cols = 4
    output_img = output_img.reshape([num_cols, 1, nSample, out_ch, out_h, out_w]).permute([1, 0, 2, 3, 4,5])
    output_img = output_img.permute([2, 3, 0, 4, 1, 5]).reshape([-1, out_ch, out_h, num_cols * out_w])
    output_img = torch.permute(output_img, [0, 2, 3, 1])
    output_img = output_img.data.cpu().numpy()
    output_img = img_as_ubyte(output_img)
    # Only the first column (the generated frames) is written to disk.
    imageio.mimsave(output_path, output_img[:,:,:out_w], fps=batch_data['fps'], quality=10, pixelformat='yuv420p', codec='libx264')
def main(args):
    """Run inference for every driving video matching ``args.driving_video``.

    Builds the model from ``args.model_config``, loads the checkpoint given by
    ``args.resume_dir`` and writes one result video per driving video into
    ``args.output_dir``.

    Args:
        args: parsed command-line namespace (see the parser below). The
            ``world_size``, ``local_rank``, ``rank`` and ``device`` attributes
            are overwritten for single-GPU execution.
    """
    # ******************************
    # single-process setup
    # ******************************
    args.world_size = 1
    args.local_rank = 0
    args.rank = 0
    args.device = torch.device("cuda", args.local_rank)
    # set seed for reproducibility
    set_seed(args.seed)

    # ******************************
    # create model
    # ******************************
    model = create_model(args.model_config).cpu()
    model.sd_locked = args.sd_locked
    model.only_mid_control = args.only_mid_control
    model.to(args.local_rank)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank == 0:
        print('Total base parameters {:.02f}M'.format(count_param([model])))
    if args.ema_rate is not None and args.ema_rate > 0 and args.rank == 0:
        print(f"Creating EMA model at ema_rate={args.ema_rate}")
        # NOTE(review): the EMA copy is created but not used below.
        model_ema = EMA(model, beta=args.ema_rate, update_after_step=0, update_every=1)
    else:
        model_ema = None

    # ******************************
    # load pre-trained models
    # ******************************
    if args.resume_dir is not None:
        if args.local_rank == 0:
            load_state_dict(model, args.resume_dir, strict=False)
    else:
        print('please privide the correct resume_dir!')
        # Non-zero status so shells/schedulers see the failure.
        sys.exit(1)

    # ******************************
    # optional compilation
    # ******************************
    if args.compile and TORCH_VERSION == "2":
        model = torch.compile(model)
    torch.cuda.set_device(args.local_rank)
    print_peak_memory("Max memory allocated after creating DDP", args.local_rank)
    infer_model = model.module if hasattr(model, "module") else model

    # ******************************
    # inference
    # ******************************
    with torch.no_grad():
        driving_videos = glob.glob(args.driving_video)
        if not driving_videos:
            print ('no driving video matches {}'.format(args.driving_video))
        for driving_video in driving_videos:
            print ('working on {}'.format(os.path.basename(driving_video)))
            infer_batch_data = x_portrait_data_prep(
                args.source_image, driving_video, args.device, args.best_frame,
                start_idx = args.start_idx, num_frames = args.out_frames,
                skip=args.skip, output_local=True,
                # Forward the CLI option, which was parsed but never used;
                # getattr keeps hand-built namespaces without it working.
                more_source_image_pattern=getattr(args, "more_source_image_pattern", ""),
            )
            infer_batch_data['video_name'] = os.path.basename(driving_video)
            infer_batch_data['source_name'] = args.source_image
            nSample = infer_batch_data['sources'].shape[0]
            visualize_mm(args, "inference", infer_batch_data, infer_model, nSample=nSample, local_image_dir=args.output_dir, num_mix=args.num_mix)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run inference.

    def str2bool(arg):
        """Parse a CLI boolean: an integer ("0"/"1") or a true/false word."""
        text = str(arg).strip().lower()
        if text in ("true", "t", "yes", "y"):
            return True
        if text in ("false", "f", "no", "n"):
            return False
        # Integer form; a ValueError here is reported by argparse.
        return bool(int(text))

    parser = argparse.ArgumentParser(description='ControlNet-based portrait animation inference')
    ## Model
    parser.add_argument('--model_config', type=str, default="model_lib/ControlNet/models/cldm_v15_video_appearance.yaml",
                        help="The path of model config file")
    parser.add_argument('--reinit_hint_block', action='store_true', default=False,
                        help="Re-initialize hint blocks for channel mis-match")
    parser.add_argument('--sd_locked', type =str2bool, default=True,
                        help='Freeze parameters in original stable-diffusion decoder')
    parser.add_argument('--only_mid_control', type =str2bool, default=False,
                        help='Only control middle blocks')
    parser.add_argument('--control_type', type=str, default="appearance_pose_local_mm",
                        help='The type of conditioning')
    parser.add_argument("--control_mode", type=str, default="controlnet_important",
                        help="Set controlnet is more important or balance.")
    # store_false: the option is ON by default and passing the flag turns it OFF.
    parser.add_argument('--wonoise', action='store_false', default=True,
                        help='Reference image is used without added noise by default; '
                             'pass this flag to disable that')
    ## Runtime
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for initialization')
    # store_false: fp16 autocast is ON by default and passing the flag turns it OFF.
    parser.add_argument('--use_fp16', action='store_false', default=True,
                        help='16-bit mixed precision is enabled by default; '
                             'pass this flag to run in 32-bit')
    parser.add_argument('--compile', type=str2bool, default=False,
                        help='compile model (for torch 2)')
    parser.add_argument('--eta', type = float, default = 0.0,
                        help='eta during DDIM Sampling')
    parser.add_argument('--ema_rate', type = float, default = 0,
                        help='rate for ema')
    ## inference
    parser.add_argument("--initial_facevid2vid_results", type=str, default=None,
                        help="facevid2vid results for noise initialization")
    parser.add_argument('--ddim_steps', type = int, default = 1,
                        help='denoising steps')
    parser.add_argument('--uc_scale', type = int, default = 5,
                        help='cfg')
    parser.add_argument("--num_drivings", type = int, default = 16,
                        help="Number of driving images in a single sequence of video.")
    parser.add_argument("--output_dir", type=str, default=None, required=True,
                        help="The output directory where the generated videos will be written.")
    parser.add_argument("--resume_dir", type=str, default=None,
                        help="Path of the model checkpoint to load.")
    parser.add_argument("--source_image", type=str, default="",
                        help="The source image for neural motion.")
    parser.add_argument("--more_source_image_pattern", type=str, default="",
                        help="Glob pattern of additional source images.")
    parser.add_argument("--driving_video", type=str, default="",
                        help="Glob pattern of the driving video(s); one result is produced per match.")
    parser.add_argument('--best_frame', type=int, default=0,
                        help='best matching frame index')
    parser.add_argument('--start_idx', type=int, default=0,
                        help='starting frame index')
    parser.add_argument('--skip', type=int, default=1,
                        help='skip frame')
    parser.add_argument('--num_mix', type=int, default=4,
                        help='num overlapping frames')
    parser.add_argument('--out_frames', type=int, default=0,
                        help='num frames')
    args = parser.parse_args()
    main(args)