jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import yaml
import os
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
import torch
import typing
import traceback
import einops
import gc
import torchvision.transforms.functional as transform
from comfy.model_management import soft_empty_cache, get_torch_device
import numpy as np
# Release endpoints that are tried, in this order, when a checkpoint is fetched by name.
BASE_MODEL_DOWNLOAD_URLS = [
    "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/",
    "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/",
    "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/"
]

# The configuration file is expected to sit next to this module.
config_path = os.path.join(os.path.dirname(__file__), "./config.yaml")
if not os.path.exists(config_path):
    raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation")

# NOTE(review): FullLoader accepts more than plain scalars/lists/maps; if the
# config only ever holds plain data, yaml.safe_load would be the safer choice
# -- confirm before switching.
# The context manager closes the file handle once parsing is done.
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Device that inference inputs are moved to.
DEVICE = get_torch_device()
class InterpolationStateList():
    """A list of frame indices plus a flag saying how that list is interpreted.

    When ``is_skip_list`` is truthy the listed frames are the skipped ones;
    otherwise every frame that is NOT listed is skipped.
    """

    def __init__(self, frame_indices: typing.List[int], is_skip_list: bool):
        self.is_skip_list = is_skip_list
        self.frame_indices = frame_indices

    def is_frame_skipped(self, frame_index):
        """Return True when ``frame_index`` is skipped under the current mode."""
        listed = frame_index in self.frame_indices
        if self.is_skip_list:
            # Skip-list mode: membership means "skip".
            return listed
        # Keep-list mode: only the listed frames are processed.
        return not listed
class MakeInterpolationStateList:
    """Builds an InterpolationStateList from a comma-separated string of frame indices.

    The class exposes the INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY
    attributes; ``FUNCTION`` names ``create_options`` as the entry point.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}),
                "is_skip_list": ("BOOLEAN", {"default": True},),
            },
        }

    RETURN_TYPES = ("INTERPOLATION_STATES",)
    FUNCTION = "create_options"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    @staticmethod
    def _parse_frame_indices(frame_indices: str) -> typing.List[int]:
        """Parse a comma-separated index string, ignoring blank entries.

        Blank entries (empty input, trailing or doubled commas, whitespace-only
        tokens) are dropped instead of crashing ``int()``. Tokens that are not
        integers still raise ``ValueError``.
        """
        stripped_tokens = (token.strip() for token in frame_indices.split(','))
        return [int(token) for token in stripped_tokens if token]

    def create_options(self, frame_indices: str, is_skip_list: bool):
        """Return a one-element tuple holding the constructed InterpolationStateList."""
        interpolation_state_list = InterpolationStateList(
            frame_indices=self._parse_frame_indices(frame_indices),
            is_skip_list=is_skip_list,
        )
        return (interpolation_state_list,)
def get_ckpt_container_path(model_type):
    """Return the absolute checkpoint directory for ``model_type``.

    The path is this module's directory joined with the ``ckpts_path`` entry
    of the loaded config and then with ``model_type``.
    """
    module_dir = os.path.dirname(__file__)
    container = os.path.join(module_dir, config["ckpts_path"], model_type)
    return os.path.abspath(container)
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    # Honour an explicit file_name; only fall back to the last path component
    # of the URL when the caller did not provide one.
    if file_name is None:
        file_name = os.path.basename(urlparse(url).path)

    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    # A file that already exists at the target path is reused without downloading.
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
def load_file_from_github_release(model_type, ckpt_name):
    """Fetch ``ckpt_name`` by trying each base URL in BASE_MODEL_DOWNLOAD_URLS in order.

    Returns the path from the first endpoint that succeeds. If every endpoint
    fails, raises an Exception whose message concatenates the traceback of
    each attempt.
    """
    failures = []
    last_position = len(BASE_MODEL_DOWNLOAD_URLS) - 1
    for position, base_url in enumerate(BASE_MODEL_DOWNLOAD_URLS):
        ckpt_url = base_url + ckpt_name
        try:
            return load_file_from_url(ckpt_url, get_ckpt_container_path(model_type))
        except Exception:
            trace = traceback.format_exc()
            # Only announce a retry while another endpoint is still left.
            if position < last_position:
                print("Failed! Trying another endpoint.")
            failures.append(f"Error when downloading from: {ckpt_url}\n\n{trace}")

    error_log = '\n\n'.join(failures)
    raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no suceess. Below is the error log:\n\n{error_log}")
def load_file_from_direct_url(model_type, url):
    """Download ``url`` (if not already present) into the checkpoint directory of ``model_type``."""
    target_dir = get_ckpt_container_path(model_type)
    return load_file_from_url(url, target_dir)
def preprocess_frames(frames):
    """Keep the first three entries of the last axis and reorder "n h w c" to "n c h w"."""
    first_three_channels = frames[..., :3]
    return einops.rearrange(first_three_channels, "n h w c -> n c h w")
def postprocess_frames(frames):
    """Reorder "n c h w" to "n h w c", keep the first three entries of the last axis, and call ``.cpu()``."""
    channels_last = einops.rearrange(frames, "n c h w -> n h w c")
    return channels_last[..., :3].cpu()
def assert_batch_size(frames, batch_size=2, vfi_name=None):
    """Check that ``frames`` holds at least ``batch_size`` frames.

    Args:
        frames: Object supporting ``len()`` and exposing ``.shape``.
        batch_size: Minimum number of frames required. Default: 2.
        vfi_name: Optional model name used in the error message.

    Raises:
        AssertionError: When fewer than ``batch_size`` frames are given. It is
            raised explicitly rather than via an ``assert`` statement so the
            check is not stripped when Python runs with ``-O``.
    """
    if len(frames) >= batch_size:
        return
    subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires"
    raise AssertionError(
        f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage."
    )
def _generic_frame_loop(
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: InterpolationStateList = None,
        use_timestep=True,
        dtype=torch.float16,
        final_logging=True):
    """Interpolate between every consecutive pair of ``frames`` with one fixed multiplier.

    Args:
        frames: Input frames; indexed and sliced along the first axis and
            expected to expose ``.shape`` and ``.to(...)`` (assumed to be a
            torch tensor -- confirm against callers).
        clear_cache_after_n_frames: Number of processed frame pairs after which
            ``soft_empty_cache()`` is called.
        multiplier: Used here as an integer; ``multiplier - 1`` middle frames
            are produced per processed pair.
        return_middle_frame_function: Callable invoked as
            ``fn(frame0, frame1, timestep, *return_middle_frame_function_args)``;
            ``timestep`` is ``None`` in the non-timestep path.
        interpolation_states: Optional skip specification; pairs whose first
            frame index is reported as skipped get no middle frames.
        use_timestep: If True, query the function once per middle frame with
            ``timestep = i / multiplier``; otherwise bisect recursively.
        dtype: dtype of the returned CPU tensor.
        final_logging: Whether to print the summary and final cache messages.

    Returns:
        CPU tensor holding the input frames with the generated middle frames
        interleaved, truncated to the number of frames actually written.
    """
    # https://github.com/hzwer/Practical-RIFE/blob/main/inference_video.py#L169
    def non_timestep_inference(frame0, frame1, n):
        # Recursively bisect the interval; expects n >= 1.
        middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args)
        if n == 1:
            return [middle]
        first_half = non_timestep_inference(frame0, middle, n=n//2)
        second_half = non_timestep_inference(middle, frame1, n=n//2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    # Preallocated upper bound; only the first out_len entries are returned.
    output_frames = torch.zeros(multiplier*frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu")
    out_len = 0
    number_of_frames_processed_since_last_cleared_cuda_cache = 0

    for frame_itr in range(len(frames) - 1):  # Skip the final frame since there are no frames after it
        frame0 = frames[frame_itr:frame_itr+1]
        output_frames[out_len] = frame0  # Start with first frame
        out_len += 1

        # Decide on skipping before doing any dtype/device conversion work.
        if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
            continue

        # Ensure that input frames are in fp32 - the same dtype as model.
        # The device transfer is done once per pair instead of once per middle frame.
        frame0 = frame0.to(dtype=torch.float32).to(DEVICE)
        frame1 = frames[frame_itr+1:frame_itr+2].to(dtype=torch.float32).to(DEVICE)

        # Generate and append a batch of middle frames
        middle_frame_batches = []
        if use_timestep:
            for middle_i in range(1, multiplier):
                timestep = middle_i/multiplier
                middle_frame = return_middle_frame_function(
                    frame0,
                    frame1,
                    timestep,
                    *return_middle_frame_function_args
                ).detach().cpu()
                middle_frame_batches.append(middle_frame.to(dtype=dtype))
        elif multiplier > 1:
            # With multiplier == 1 there is nothing to generate; calling the
            # recursive helper with n == 0 would never reach its base case.
            middle_frames = non_timestep_inference(frame0, frame1, multiplier - 1)
            middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype))

        # Copy middle frames to output
        for middle_frame in middle_frame_batches:
            output_frames[out_len] = middle_frame
            out_len += 1

        number_of_frames_processed_since_last_cleared_cuda_cache += 1
        # Try to avoid a memory overflow by clearing cuda cache regularly
        if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
            print("Comfy-VFI: Clearing cache...", end=' ')
            soft_empty_cache()
            number_of_frames_processed_since_last_cleared_cuda_cache = 0
            print("Done cache clearing")

        gc.collect()

    # Append final frame
    output_frames[out_len] = frames[-1:]
    out_len += 1

    if final_logging:
        # Report the frames actually written, not the preallocated buffer size.
        print(f"Comfy-VFI done! {out_len} frames generated at resolution: {output_frames[0].shape}")

    # clear cache for courtesy
    if final_logging:
        print("Comfy-VFI: Final clearing cache...", end=' ')
    soft_empty_cache()
    if final_logging:
        print("Done cache clearing")

    return output_frames[:out_len]
def generic_frame_loop(
        model_name,
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: InterpolationStateList = None,
        use_timestep=True,
        dtype=torch.float32):
    """Interpolate ``frames`` with either one multiplier or a per-pair list of multipliers.

    Args:
        model_name: Name used in the batch-size error message (underscores are
            replaced by spaces and "VFI" is removed).
        frames: Input frames, indexed along the first axis.
        clear_cache_after_n_frames: Forwarded to ``_generic_frame_loop``.
        multiplier: An ``int`` applied to every pair, or a ``list`` giving one
            multiplier per consecutive pair. A short list is padded with 2;
            a multiplier of 0 drops that pair from the output.
        return_middle_frame_function: Forwarded to ``_generic_frame_loop``
            together with ``return_middle_frame_function_args``.
        interpolation_states: Optional skip specification, evaluated against
            the index of the first frame of each pair in ``frames``.
        use_timestep: Forwarded to ``_generic_frame_loop``.
        dtype: dtype of the returned frames.

    Returns:
        Tensor of interpolated frames.

    Raises:
        NotImplementedError: If ``multiplier`` is neither an int nor a list.
    """
    assert_batch_size(frames, vfi_name=model_name.replace('_', ' ').replace('VFI', ''))

    if isinstance(multiplier, int):
        return _generic_frame_loop(
            frames,
            clear_cache_after_n_frames,
            multiplier,
            return_middle_frame_function,
            *return_middle_frame_function_args,
            interpolation_states=interpolation_states,
            use_timestep=use_timestep,
            dtype=dtype
        )

    if isinstance(multiplier, list):
        multipliers = list(map(int, multiplier))
        # Pairs without an explicit multiplier default to 2.
        multipliers += [2] * (len(frames) - len(multipliers) - 1)
        frame_batches = []
        last_pair_index = len(frames) - 2
        for frame_itr in range(len(frames) - 1):
            pair_multiplier = multipliers[frame_itr]
            if pair_multiplier == 0: continue

            # Each call below sees a two-frame slice in which this pair's first
            # frame has local index 0. Evaluate the skip decision here with the
            # real frame index and translate it into a pair-local skip list.
            pair_states = None
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
                pair_states = InterpolationStateList(frame_indices=[0], is_skip_list=True)

            frame_batch = _generic_frame_loop(
                frames[frame_itr:frame_itr+2],
                clear_cache_after_n_frames,
                pair_multiplier,
                return_middle_frame_function,
                *return_middle_frame_function_args,
                interpolation_states=pair_states,
                use_timestep=use_timestep,
                dtype=dtype,
                final_logging=False
            )
            if frame_itr != last_pair_index:  # Not append last frame unless this batch is the last one
                frame_batch = frame_batch[:-1]
            frame_batches.append(frame_batch)
        output_frames = torch.cat(frame_batches)
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
        return output_frames

    raise NotImplementedError(f"multipiler of {type(multiplier)}")
class FloatToInt:
    """Converts a float, or an iterable of floats, to int(s) via ``int()``."""

    RETURN_TYPES = ("INT",)
    FUNCTION = "convert"
    CATEGORY = "ComfyUI-Frame-Interpolation"

    @classmethod
    def INPUT_TYPES(s):
        float_spec = ("FLOAT", {"default": 0, 'min': 0, 'step': 0.01})
        return {"required": {"float": float_spec}}

    def convert(self, float):
        """Return a one-element tuple: an int, or a list of ints for iterable input."""
        # The parameter name matches the "float" key declared in INPUT_TYPES.
        if not hasattr(float, "__iter__"):
            return (int(float),)
        return ([int(value) for value in float],)
""" def generic_4frame_loop(
frames,
clear_cache_after_n_frames,
multiplier: typing.SupportsInt,
return_middle_frame_function,
*return_middle_frame_function_args,
interpolation_states: InterpolationStateList = None,
use_timestep=False):
if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model")
def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n):
middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args)
if n == 1:
return [middle]
first_half = non_timestep_inference(frame_0, middle, n=n//2)
second_half = non_timestep_inference(middle, frame_1, n=n//2)
if n%2:
return [*first_half, middle, *second_half]
else:
return [*first_half, *second_half] """