# thanks to https://github.com/n00mkrad for the inspiration and a bit of code. Also thanks to https://github.com/XmYx for the initial reorganization of this script
import os
from types import SimpleNamespace
import cv2
import torch
import shutil
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
import warnings
import _thread
from queue import Queue
import time
from .model.pytorch_msssim import ssim_matlab
from deforum_helpers.video_audio_utilities import ffmpeg_stitch_video
from deforum_helpers.general_utils import duplicate_pngs_from_folder
warnings.filterwarnings("ignore")
def run_rife_new_video_infer(
output=None,
model=None,
fp16=False,
        UHD=False,  # will be received as True if the images'/video's resolution is 2K or higher
scale=1.0,
fps=None,
deforum_models_path=None,
raw_output_imgs_path=None,
img_batch_id=None,
ffmpeg_location=None,
audio_track=None,
interp_x_amount=2,
slow_mo_enabled=False,
slow_mo_x_amount=2,
ffmpeg_crf=17,
ffmpeg_preset='veryslow',
keep_imgs=False,
        orig_vid_name=None,
        srt_path=None):
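    """Frame-interpolate a folder of PNGs with RIFE, then stitch the result into an mp4 via ffmpeg.

    Illustrative call only; the paths and model name below are placeholder assumptions
    (modelDir is the base name of the model's .pkl, per the load/print logic further down):

        run_rife_new_video_infer(
            model='RIFE46', fps=30, deforum_models_path='models/Deforum',
            raw_output_imgs_path='outputs/img2img-images/Deforum',
            img_batch_id='20230101120000', ffmpeg_location='ffmpeg',
            interp_x_amount=2)
    """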
args = SimpleNamespace()
args.output = output
args.modelDir = model
args.fp16 = fp16
args.UHD = UHD
args.scale = scale
args.fps = fps
args.deforum_models_path = deforum_models_path
args.raw_output_imgs_path = raw_output_imgs_path
args.img_batch_id = img_batch_id
args.ffmpeg_location = ffmpeg_location
args.audio_track = audio_track
args.interp_x_amount = interp_x_amount
args.slow_mo_enabled = slow_mo_enabled
args.slow_mo_x_amount = slow_mo_x_amount
args.ffmpeg_crf = ffmpeg_crf
args.ffmpeg_preset = ffmpeg_preset
args.keep_imgs = keep_imgs
args.orig_vid_name = orig_vid_name
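    # 2K+ inputs get half the optical-flow scale to keep RIFE's memory use and runtime manageable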
if args.UHD and args.scale == 1.0:
args.scale = 0.5
start_time = time.time()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
        # TODO: can/need we handle this? Currently it's always False and errors out when True,
        # but fp16 would give faster speeds on Tensor-Core-equipped GPUs
        if args.fp16:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
if args.modelDir is not None:
try:
from .rife_new_gen.RIFE_HDv3 import Model
except ImportError as e:
raise ValueError(f"{args.modelDir} could not be found. Please contact deforum support {e}")
except Exception as e:
raise ValueError(f"An error occured while trying to import {args.modelDir}: {e}")
else:
print("Got a request to frame-interpolate but no valid frame interpolation engine value provided. Doing... nothing")
return
model = Model()
if not hasattr(model, 'version'):
model.version = 0
model.load_model(args.modelDir, -1, deforum_models_path)
model.eval()
model.device()
print(f"{args.modelDir}.pkl model successfully loaded into memory")
print("Interpolation progress (it's OK if it finishes before 100%):")
interpolated_path = os.path.join(args.raw_output_imgs_path, 'interpolated_frames_rife')
    # set a custom folder name depending on whether we interpolate right after a deforum run, or a video (deforum-related or not, we don't know) directly from within the RIFE tab
if args.orig_vid_name is not None: # interpolating a video (deforum or unrelated)
custom_interp_path = "{}_{}".format(interpolated_path, args.orig_vid_name)
else: # interpolating after a deforum run:
custom_interp_path = "{}_{}".format(interpolated_path, args.img_batch_id)
    # In this folder we temporarily keep the original frames (converted or copy-pasted; the image format depends on the scenario).
    # The conversion is done to avert problems with mixed 24-bit and 32-bit outputs from the same animation run.
temp_convert_raw_png_path = os.path.join(args.raw_output_imgs_path, "tmp_rife_folder")
duplicate_pngs_from_folder(args.raw_output_imgs_path, temp_convert_raw_png_path, args.img_batch_id, args.orig_vid_name)
videogen = []
for f in os.listdir(temp_convert_raw_png_path):
        # double-check for old _depth_ files; probably not really needed, but keeping it for now
if '_depth_' not in f:
videogen.append(f)
tot_frame = len(videogen)
    videogen.sort(key=lambda x: int(x.split('.')[0]))
img_path = os.path.join(temp_convert_raw_png_path, videogen[0])
lastframe = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
videogen = videogen[1:]
h, w, _ = lastframe.shape
vid_out = None
if not os.path.exists(custom_interp_path):
os.mkdir(custom_interp_path)
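    # pad frame dims up to a multiple of 128 (scaled), so the flow network's repeated downsampling divides evenly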
tmp = max(128, int(128 / args.scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
pbar = tqdm(total=tot_frame)
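    # read and write frames on background threads; the bounded queues overlap disk I/O with GPU inference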
write_buffer = Queue(maxsize=500)
read_buffer = Queue(maxsize=500)
_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen, temp_convert_raw_png_path))
_thread.start_new_thread(clear_write_buffer, (args, write_buffer, custom_interp_path))
I1 = torch.from_numpy(np.transpose(lastframe, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = pad_image(I1, args.fp16, padding)
    temp = None  # holds the peeked-ahead frame when a static (near-duplicate) frame is being processed
while True:
if temp is not None:
frame = temp
temp = None
else:
frame = read_buffer.get()
if frame is None:
break
I0 = I1
I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = pad_image(I1, args.fp16, padding)
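        # compare cheap 32x32 downscales of the two endpoint frames; the SSIM score drives the
        # duplicate-frame and scene-cut heuristics below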
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
break_flag = False
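        # near-identical consecutive frames (e.g. a held shot): peek one frame ahead and replace I1 with a
        # synthesized in-between, so the interpolator works from distinct endpoints instead of a duplicate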
if ssim > 0.996:
frame = read_buffer.get() # read a new frame
if frame is None:
break_flag = True
frame = lastframe
else:
temp = frame
I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = pad_image(I1, args.fp16, padding)
I1 = model.inference(I0, I1, args.scale)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
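        # an SSIM this low usually indicates a scene cut, so duplicate I0 rather than blend across the cut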
if ssim < 0.2:
output = []
for i in range(args.interp_x_amount - 1):
output.append(I0)
else:
            output = make_inference(model, I0, I1, args.interp_x_amount - 1, args.scale)
write_buffer.put(lastframe)
for mid in output:
mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
write_buffer.put(mid[:h, :w])
pbar.update(1)
lastframe = frame
if break_flag:
break
write_buffer.put(lastframe)
    while not write_buffer.empty():
time.sleep(0.1)
pbar.close()
shutil.rmtree(temp_convert_raw_png_path)
print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!")
# stitch video from interpolated frames, and add audio if needed
try:
print (f"*Passing interpolated frames to ffmpeg...*")
vid_out_path = stitch_video(args.img_batch_id, args.fps, custom_interp_path, args.audio_track, args.ffmpeg_location, args.interp_x_amount, args.slow_mo_enabled, args.slow_mo_x_amount, args.ffmpeg_crf, args.ffmpeg_preset, args.keep_imgs, args.orig_vid_name, srt_path=srt_path)
# remove folder with raw (non-interpolated) vid input frames in case of input VID and not PNGs
if orig_vid_name is not None:
shutil.rmtree(raw_output_imgs_path)
return vid_out_path
except Exception as e:
        print(f"Video stitching went wrong. *Interpolated frames were saved to HD as backup!* Actual error: {e}")
def clear_write_buffer(user_args, write_buffer, custom_interp_path):
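    """Writer thread: pop interpolated frames off write_buffer and save them as
    zero-padded PNGs (RGB -> BGR for cv2) until a None sentinel arrives."""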
cnt = 0
while True:
item = write_buffer.get()
if item is None:
break
        filename = os.path.join(custom_interp_path, '{:0>9d}.png'.format(cnt))
cv2.imwrite(filename, item[:, :, ::-1])
cnt += 1
def build_read_buffer(user_args, read_buffer, videogen, temp_convert_raw_png_path):
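    """Reader thread: decode each listed image into read_buffer (BGR -> RGB),
    then push a None sentinel to signal end of stream."""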
for frame in videogen:
        if temp_convert_raw_png_path is not None:
img_path = os.path.join(temp_convert_raw_png_path, frame)
frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
read_buffer.put(frame)
read_buffer.put(None)
def make_inference(model, I0, I1, n, scale):
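    """Produce n in-between frames. Models >= 3.9 accept an arbitrary timestep, so the
    in-betweens are sampled directly at t = (i + 1) / (n + 1); older models only predict
    the midpoint, so the set is built by recursive bisection."""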
if model.version >= 3.9:
res = []
for i in range(n):
res.append(model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale))
return res
else:
middle = model.inference(I0, I1, scale)
if n == 1:
return [middle]
first_half = make_inference(model, I0, middle, n=n // 2, scale=scale)
second_half = make_inference(model, middle, I1, n=n // 2, scale=scale)
if n % 2:
return [*first_half, middle, *second_half]
else:
return [*first_half, *second_half]
def pad_image(img, fp16, padding):
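    """Right/bottom-pad img to the caller-computed network-friendly size, casting to half precision when fp16 is on."""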
    if fp16:
        return F.pad(img, padding).half()
    else:
        return F.pad(img, padding)
# TODO: move to frame_interpolation and add FILM to it!
def stitch_video(img_batch_id, fps, img_folder_path, audio_path, ffmpeg_location, interp_x_amount, slow_mo_enabled, slow_mo_x_amount, f_crf, f_preset, keep_imgs, orig_vid_name, srt_path=None):
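    """Name the output mp4 from the batch id (or original video name) and hand the interpolated
    frames to ffmpeg. On success the frame folder is deleted unless keep_imgs is set; for video
    inputs the frames may instead be moved next to the output for safekeeping."""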
parent_folder = os.path.dirname(img_folder_path)
grandparent_folder = os.path.dirname(parent_folder)
if orig_vid_name is not None:
        mp4_path = os.path.join(grandparent_folder, str(orig_vid_name) + '_RIFE_x' + str(interp_x_amount))
    else:
        mp4_path = os.path.join(parent_folder, str(img_batch_id) + '_RIFE_x' + str(interp_x_amount))
if slow_mo_enabled:
mp4_path = mp4_path + '_slomo_x' + str(slow_mo_x_amount)
mp4_path = mp4_path + '.mp4'
    imgs_path_pattern = os.path.join(img_folder_path, "%09d.png")
    add_soundtrack = 'None'
    if audio_path is not None:
        add_soundtrack = 'File'
exception_raised = False
try:
        ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=1000000, imgs_path=imgs_path_pattern, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset, srt_path=srt_path)
except Exception as e:
exception_raised = True
print(f"An error occurred while stitching the video: {e}")
if not exception_raised and not keep_imgs:
shutil.rmtree(img_folder_path)
    if orig_vid_name is not None and (keep_imgs or exception_raised):
shutil.move(img_folder_path, grandparent_folder)
return mp4_path