|
import yaml |
|
import os |
|
from torch.hub import download_url_to_file, get_dir |
|
from urllib.parse import urlparse |
|
import torch |
|
import typing |
|
import traceback |
|
import einops |
|
import gc |
|
import torchvision.transforms.functional as transform |
|
from comfy.model_management import soft_empty_cache, get_torch_device |
|
import numpy as np |
|
|
|
BASE_MODEL_DOWNLOAD_URLS = [ |
|
"https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/", |
|
"https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/", |
|
"https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/" |
|
] |
|
|
|
# Resolve config.yaml next to this module; it holds package settings such as
# "ckpts_path" (checkpoint directory) read by get_ckpt_container_path().
config_path = os.path.join(os.path.dirname(__file__), "./config.yaml")
if os.path.exists(config_path):
    # Use a context manager so the handle is closed deterministically
    # (the original `yaml.load(open(...))` left it open until GC).
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
else:
    raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation")
# Torch device shared by all VFI nodes in this package.
DEVICE = get_torch_device()
|
|
|
class InterpolationStateList():
    """Tracks which frame pairs should be skipped during interpolation.

    The stored index list is interpreted in one of two modes:
      * skip list (is_skip_list=True):  listed frame indices are skipped.
      * keep list (is_skip_list=False): only listed frame indices are processed.
    """

    def __init__(self, frame_indices: typing.List[int], is_skip_list: bool):
        self.frame_indices = frame_indices
        self.is_skip_list = is_skip_list

    def is_frame_skipped(self, frame_index):
        """Return True if interpolation after `frame_index` should be skipped."""
        listed = frame_index in self.frame_indices
        # Skip-list mode skips listed frames; keep-list mode skips unlisted ones.
        return listed if self.is_skip_list else not listed
|
|
|
|
|
class MakeInterpolationStateList:
    """ComfyUI node: build an InterpolationStateList from a comma-separated
    string of frame indices typed by the user."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}),
                "is_skip_list": ("BOOLEAN", {"default": True},),
            },
        }

    RETURN_TYPES = ("INTERPOLATION_STATES",)
    FUNCTION = "create_options"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def create_options(self, frame_indices: str, is_skip_list: bool):
        """Parse the index string and wrap it in an InterpolationStateList.

        Empty segments (e.g. a trailing comma or a blank line in the multiline
        widget) are ignored; the original parser raised ValueError on them.
        int() itself tolerates surrounding whitespace.
        """
        frame_indices_list = [
            int(item) for item in frame_indices.split(',') if item.strip()
        ]

        interpolation_state_list = InterpolationStateList(
            frame_indices=frame_indices_list,
            is_skip_list=is_skip_list,
        )
        return (interpolation_state_list,)
|
|
|
|
|
def get_ckpt_container_path(model_type):
    """Absolute path of the checkpoint folder for `model_type`.

    Joins this package's directory, the configured "ckpts_path", and the
    model-type subfolder.
    """
    package_dir = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(package_dir, config["ckpts_path"], model_type))
|
|
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded (or already cached) file.
    """
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    # Honor an explicit file_name; otherwise derive it from the URL path.
    # (The original always overwrote file_name with the URL basename, so the
    # documented file_name argument was silently ignored.)
    if file_name is None:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
|
|
def load_file_from_github_release(model_type, ckpt_name):
    """Download `ckpt_name` into the `model_type` checkpoint folder, trying
    each mirror in BASE_MODEL_DOWNLOAD_URLS in order.

    Returns:
        str: Path to the (possibly already cached) checkpoint file.

    Raises:
        Exception: If every mirror fails; the message aggregates the full
            traceback of each attempt.
    """
    error_strs = []
    for i, base_model_download_url in enumerate(BASE_MODEL_DOWNLOAD_URLS):
        try:
            return load_file_from_url(base_model_download_url + ckpt_name, get_ckpt_container_path(model_type))
        except Exception:
            traceback_str = traceback.format_exc()
            # Announce the fallback only while there are untried mirrors left.
            if i < len(BASE_MODEL_DOWNLOAD_URLS) - 1:
                print("Failed! Trying another endpoint.")
            error_strs.append(f"Error when downloading from: {base_model_download_url + ckpt_name}\n\n{traceback_str}")

    error_str = '\n\n'.join(error_strs)
    # "suceess" typo in the original message fixed.
    raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no success. Below is the error log:\n\n{error_str}")
|
|
|
|
|
def load_file_from_direct_url(model_type, url):
    """Download (or reuse the cached copy of) a checkpoint from an arbitrary
    URL into the checkpoint folder for `model_type`."""
    target_dir = get_ckpt_container_path(model_type)
    return load_file_from_url(url, target_dir)
|
|
|
def preprocess_frames(frames):
    """Convert a ComfyUI image batch (N, H, W, C) to model layout (N, C, H, W).

    Only the first three channels (RGB) are kept; any alpha channel is
    dropped. Returns a (possibly non-contiguous) view, like the original
    einops permutation.
    """
    return frames[..., :3].permute(0, 3, 1, 2)
|
|
|
def postprocess_frames(frames):
    """Convert model output (N, C, H, W) back to a ComfyUI image batch
    (N, H, W, C) on the CPU, keeping only the first three channels."""
    nhwc = frames.permute(0, 2, 3, 1)
    return nhwc[..., :3].cpu()
|
|
|
def assert_batch_size(frames, batch_size=2, vfi_name=None):
    """Validate that `frames` holds at least `batch_size` frames.

    Args:
        frames: Batch of frames; only len() and .shape[0] are used.
        batch_size (int): Minimum number of frames required. Default: 2.
        vfi_name (str): Model name to mention in the error. Default: None.

    Raises:
        AssertionError: If fewer than `batch_size` frames are present. Raised
            explicitly rather than via `assert` so the check still runs under
            `python -O` (asserts are stripped there).
    """
    subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires"
    if len(frames) < batch_size:
        # Same exception type as the original `assert`, so existing callers
        # that catch AssertionError keep working.
        raise AssertionError(f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage.")
|
|
|
def _generic_frame_loop(
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: InterpolationStateList = None,
        use_timestep=True,
        dtype=torch.float16,
        final_logging=True):
    # Core interpolation loop shared by all VFI nodes.
    #
    # For every consecutive frame pair it emits the first frame followed by
    # (multiplier - 1) generated in-between frames; the final source frame is
    # appended after the loop. The result therefore has
    # multiplier * (len(frames) - 1) + 1 frames (fewer when pairs are
    # skipped), stored on the CPU in `dtype`.
    #
    # Args:
    #   frames: source frames, indexable as an (N, ...) tensor.
    #   clear_cache_after_n_frames: soft-empty the device cache after this
    #       many processed pairs (VRAM pressure vs. speed trade-off).
    #   multiplier: integer frame-rate multiplier for every pair (the
    #       per-pair list form is handled by generic_frame_loop, which calls
    #       this once per pair).
    #   return_middle_frame_function: model callable producing one middle
    #       frame from (frame0, frame1, timestep, *extra_args).
    #   interpolation_states: optional skip/keep list consulted per pair.
    #   use_timestep: True -> query the model at evenly spaced timesteps;
    #       False -> recursive midpoint bisection (for models that take no
    #       timestep input).
    #   dtype: storage dtype of the output buffer.
    #   final_logging: False suppresses the summary/cache log lines (used
    #       when invoked per-pair by generic_frame_loop).

    def non_timestep_inference(frame0, frame1, n):
        # Recursively bisect the interval: generate the midpoint frame, then
        # fill each half with n//2 frames (plus the midpoint itself when n is
        # odd).
        middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args)
        if n == 1:
            return [middle]
        first_half = non_timestep_inference(frame0, middle, n=n//2)
        second_half = non_timestep_inference(middle, frame1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]

    # Pre-allocate the worst-case output buffer; out_len tracks how many rows
    # are actually filled and the buffer is trimmed on return.
    output_frames = torch.zeros(multiplier*frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu")
    out_len = 0

    number_of_frames_processed_since_last_cleared_cuda_cache = 0

    for frame_itr in range(len(frames) - 1):
        # 1-wide slice keeps the batch dimension; copy the source frame out.
        frame0 = frames[frame_itr:frame_itr+1]
        output_frames[out_len] = frame0
        out_len += 1

        # Inference runs in float32 regardless of the storage dtype.
        frame0 = frame0.to(dtype=torch.float32)
        frame1 = frames[frame_itr+1:frame_itr+2].to(dtype=torch.float32)

        # A skipped pair still emits frame0 above — only the middle frames
        # are omitted.
        if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
            continue

        middle_frame_batches = []

        if use_timestep:
            # Evenly spaced timesteps in (0, 1): 1/m, 2/m, ..., (m-1)/m.
            for middle_i in range(1, multiplier):
                timestep = middle_i/multiplier

                middle_frame = return_middle_frame_function(
                    frame0.to(DEVICE),
                    frame1.to(DEVICE),
                    timestep,
                    *return_middle_frame_function_args
                ).detach().cpu()
                middle_frame_batches.append(middle_frame.to(dtype=dtype))
        else:
            middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1)
            # cat then iterate: extend() appends one (C, H, W)-style row per
            # generated frame.
            middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype))

        for middle_frame in middle_frame_batches:
            output_frames[out_len] = middle_frame
            out_len += 1

        number_of_frames_processed_since_last_cleared_cuda_cache += 1

        # Periodically release cached device memory to bound VRAM usage.
        if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
            print("Comfy-VFI: Clearing cache...", end=' ')
            soft_empty_cache()
            number_of_frames_processed_since_last_cleared_cuda_cache = 0
            print("Done cache clearing")

        gc.collect()

    if final_logging:
        # NOTE(review): this prints the pre-allocated buffer length
        # (multiplier * len(frames)), not the number of frames actually
        # generated (out_len + 1) — the logged count can be too high.
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")

    # Append the final source frame, which the pair loop never emits.
    output_frames[out_len] = frames[-1:]
    out_len += 1

    if final_logging:
        print("Comfy-VFI: Final clearing cache...", end = ' ')
    soft_empty_cache()
    if final_logging:
        print("Done cache clearing")
    # Trim the unused tail of the pre-allocated buffer.
    return output_frames[:out_len]
|
|
|
def generic_frame_loop(
        model_name,
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: InterpolationStateList = None,
        use_timestep=True,
        dtype=torch.float32):
    """Public entry point wrapping _generic_frame_loop.

    Accepts either a single int `multiplier` (applied to every frame pair)
    or a list of per-pair multipliers (a multiplier of 0 drops that pair's
    interpolation; missing entries default to 2). All other arguments are
    forwarded to _generic_frame_loop unchanged.

    Raises:
        AssertionError: If `frames` has fewer than 2 frames.
        NotImplementedError: If `multiplier` is neither an int nor a list.
    """
    assert_batch_size(frames, vfi_name=model_name.replace('_', ' ').replace('VFI', ''))
    # isinstance instead of type()==: idiomatic and accepts int subclasses.
    if isinstance(multiplier, int):
        return _generic_frame_loop(
            frames,
            clear_cache_after_n_frames,
            multiplier,
            return_middle_frame_function,
            *return_middle_frame_function_args,
            interpolation_states=interpolation_states,
            use_timestep=use_timestep,
            dtype=dtype
        )
    if isinstance(multiplier, list):
        multipliers = list(map(int, multiplier))
        # Pad with the default 2x so every one of the len(frames)-1 pairs
        # has a multiplier.
        multipliers += [2] * (len(frames) - len(multipliers) - 1)
        frame_batches = []
        for frame_itr in range(len(frames) - 1):
            multiplier = multipliers[frame_itr]
            if multiplier == 0: continue
            frame_batch = _generic_frame_loop(
                frames[frame_itr:frame_itr+2],
                clear_cache_after_n_frames,
                multiplier,
                return_middle_frame_function,
                *return_middle_frame_function_args,
                interpolation_states=interpolation_states,
                use_timestep=use_timestep,
                dtype=dtype,
                final_logging=False
            )
            # Each per-pair batch ends with the pair's second source frame,
            # which is also the next batch's first frame — drop the duplicate
            # except on the last pair.
            if frame_itr != len(frames) - 2:
                frame_batch = frame_batch[:-1]
            frame_batches.append(frame_batch)
        output_frames = torch.cat(frame_batches)
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
        return output_frames
    # "multipiler" typo in the original message fixed.
    raise NotImplementedError(f"multiplier of {type(multiplier)}")
|
|
|
class FloatToInt:
    """ComfyUI utility node that truncates a float — or each element of an
    iterable of floats — to int."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float": ("FLOAT", {"default": 0, 'min': 0, 'step': 0.01})
            }
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "convert"
    CATEGORY = "ComfyUI-Frame-Interpolation"

    def convert(self, float):
        # `float` shadows the builtin, but the name is part of the node's
        # input contract and cannot change.
        if not hasattr(float, "__iter__"):
            return (int(float),)
        # Iterable input: truncate element-wise.
        return ([int(value) for value in float],)
|
|
|
""" def generic_4frame_loop( |
|
frames, |
|
clear_cache_after_n_frames, |
|
multiplier: typing.SupportsInt, |
|
return_middle_frame_function, |
|
*return_middle_frame_function_args, |
|
interpolation_states: InterpolationStateList = None, |
|
use_timestep=False): |
|
|
|
if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model") |
|
def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n): |
|
middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args) |
|
if n == 1: |
|
return [middle] |
|
first_half = non_timestep_inference(frame_0, middle, n=n//2) |
|
second_half = non_timestep_inference(middle, frame_1, n=n//2) |
|
if n%2: |
|
return [*first_half, middle, *second_half] |
|
else: |
|
return [*first_half, *second_half] """ |