File size: 4,794 Bytes
e61bb9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import gc
from glob import glob
import bisect
from tqdm import tqdm
import torch
import numpy as np
import cv2
from .film_util import load_image
import time
from types import SimpleNamespace
from modules.shared import cmd_opts
import warnings
warnings.filterwarnings("ignore")
def run_film_interp_infer(
        model_path=None,
        input_folder=None,
        save_folder=None,
        inter_frames=None):
    """Interpolate between consecutive images of a folder with a FILM TorchScript model.

    Every consecutive pair of ``.jpg``/``.png`` files (any letter case, hidden files
    excluded, sorted by path) in ``input_folder`` gets ``inter_frames`` synthesised
    in-between frames. Each pair's frames, endpoints included, are written to
    ``save_folder`` as ``frame_<9-digit number>.png``. The closing frame of a pair is
    only written for the final pair, because it is also the opening frame of the next
    pair. Numbering continues after the highest number already found in ``save_folder``.

    Args:
        model_path: Path of the TorchScript model passed to ``torch.jit.load``. The model
            is moved to CUDA and cast to half precision unless ``cmd_opts.no_half`` is set.
        input_folder: Folder holding the source frames.
        save_folder: Output folder; created if it does not exist.
        inter_frames: Number of frames to generate between each pair (0 only re-writes
            the cropped input frames).

    Returns:
        None. Problems with ``input_folder`` are printed and the function returns early.
    """
    # Check if the folder exists (isdir also rejects None and plain files,
    # which os.listdir below could not handle).
    if input_folder is None or not os.path.isdir(input_folder):
        print(f"Error: Folder '{input_folder}' does not exist.")
        return
    # Collect JPG/PNG images sorted by name; the same case-insensitive rule is used for
    # the emptiness check and for the actual frame list.
    image_paths = _list_frame_images(input_folder)
    if not image_paths:
        print(f"Error: Folder '{input_folder}' does not contain any PNG or JPEG images.")
        return
    start_time = time.time()  # Timer START
    print(f"Total frames to FILM-interpolate: {len(image_paths)}. Total frame-pairs: {len(image_paths)-1}.")

    model = torch.jit.load(model_path, map_location='cpu')
    # Half precision the model if the user didn't pass --no-half; the inputs fed to the
    # model must use the same dtype, otherwise inference fails on a dtype mismatch.
    use_half = not cmd_opts.no_half
    if use_half:
        model = model.half()
    model = model.cuda()
    model.eval()
    input_dtype = torch.float16 if use_half else torch.float32

    try:
        # Create the output folder for the interpolated images once, then find where
        # numbering starts. Afterwards the counter is tracked locally instead of
        # re-listing the folder for every pair.
        os.makedirs(save_folder, exist_ok=True)
        next_number = _next_frame_number(save_folder)
        last_pair = len(image_paths) - 2

        for pair_idx in tqdm(range(len(image_paths) - 1), desc='FILM progress'):
            # NOTE(review): assumes load_image returns a (1, H, W, C) float array in [0, 1]
            # plus a (y1, x1, y2, x2) crop region; both images are assumed to share the
            # first image's crop - confirm against film_util.
            batch_a, crop_region = load_image(image_paths[pair_idx])
            batch_b, _ = load_image(image_paths[pair_idx + 1])
            first = torch.from_numpy(batch_a).permute(0, 3, 1, 2)   # NHWC -> NCHW
            second = torch.from_numpy(batch_b).permute(0, 3, 1, 2)

            results = _interpolate_pair(model, first, second, inter_frames, input_dtype)

            y1, x1, y2, x2 = crop_region
            # flip(0) reverses the channel axis (presumably RGB -> BGR for cv2.imwrite).
            frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy()
                      for tensor in results]
            # A pair's last frame is the next pair's first frame; only the final pair
            # writes it, so shared endpoints are not duplicated.
            if pair_idx != last_pair:
                frames = frames[:-1]
            for frame in frames:
                cv2.imwrite(os.path.join(save_folder, f"frame_{next_number:09d}.png"), frame)
                next_number += 1
    finally:
        # Remove the FILM model from memory even if interpolation failed part-way.
        del model
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!")


# Return the sorted paths of non-hidden .jpg/.png files (any case) directly inside folder.
def _list_frame_images(folder):
    # Filtering os.listdir avoids glob patterns, which break when the folder path
    # itself contains glob metacharacters such as '[' or ']'.
    names = (name for name in os.listdir(folder)
             if not name.startswith('.') and name.lower().endswith(('.jpg', '.png')))
    return sorted(os.path.join(folder, name) for name in names)


# Return 1 + the highest "<prefix>_<n>.<ext>" number in folder, or 0 when there is none.
def _next_frame_number(folder):
    highest = -1
    for name in os.listdir(folder):
        try:
            highest = max(highest, int(name.split("_")[1].split(".")[0]))
        except (IndexError, ValueError):
            # Not a numbered frame file (stray or hidden file); ignore it instead of crashing.
            continue
    return highest + 1


# Synthesise inter_frames frames between two NCHW tensors; return all frames in time order.
def _interpolate_pair(model, first, second, inter_frames, dtype):
    results = [first, second]
    idxes = [0, inter_frames + 1]               # timeline slots already filled, ascending
    remains = list(range(1, inter_frames + 1))  # timeline slots still to generate
    splits = torch.linspace(0, 1, inter_frames + 2)
    while remains:
        # For every (filled interval, missing slot) combination, measure how far the
        # slot's relative position is from the interval midpoint. Generate the closest
        # slot next, so each prediction is made between the nearest known frames.
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
        start_i, step = np.unravel_index(torch.argmin(distances).item(), distances.shape)
        end_i = start_i + 1
        target = remains[step]

        x0 = results[start_i].to(dtype).cuda()
        x1 = results[end_i].to(dtype).cuda()
        # Relative time of the target slot inside its bracketing interval, in [0, 1].
        ratio = (splits[target] - splits[idxes[start_i]]) / (splits[idxes[end_i]] - splits[idxes[start_i]])
        dt = x0.new_full((1, 1), ratio.item())
        with torch.no_grad():
            prediction = model(x0, x1, dt)

        insert_position = bisect.bisect_left(idxes, target)
        idxes.insert(insert_position, target)
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]
    return results