|
import torch |
|
from comfy.model_management import get_torch_device, soft_empty_cache |
|
import numpy as np |
|
import typing |
|
from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames, assert_batch_size |
|
import pathlib |
|
import warnings |
|
from .flavr_arch import UNet_3D_3D, InputPadder |
|
import gc |
|
|
|
# Number of consecutive input frames the FLAVR network consumes per window
# (passed as ``n_inputs`` when the model is built).
NBR_FRAME = 4

# Torch device reported by ComfyUI's model management; the model and each
# input window are moved here before inference.
device = get_torch_device()
|
|
|
def build_flavr(model_path):
    """Load a FLAVR checkpoint and return the network on ``device`` in eval mode.

    The number of frames the network predicts per input window is derived from
    the checkpoint itself: the first dimension of ``outconv.1.weight`` divided
    by 3 (presumably 3 colour channels per output frame).

    Args:
        model_path: Path to a checkpoint whose top-level mapping contains a
            ``"state_dict"`` entry.

    Returns:
        The ``UNet_3D_3D`` model with the checkpoint weights loaded, moved to
        ``device`` and switched to eval mode.
    """
    # Restore tensors onto the CPU first. Without map_location, a checkpoint
    # saved from a GPU is restored onto that GPU (or fails on CPU-only hosts)
    # before the explicit move to ``device`` below.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    sd = torch.load(model_path, map_location="cpu")["state_dict"]

    # Keys may carry a "module." prefix (as nn.DataParallel-wrapped models
    # produce). Strip it only where present. The previous
    # ``k.partition("module.")[-1]`` mapped every unprefixed key to "".
    prefix = "module."
    sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

    model = UNet_3D_3D(
        "unet_18",
        n_inputs=NBR_FRAME,
        n_outputs=sd["outconv.1.weight"].shape[0] // 3,
        joinType="concat",
        upmode="transpose",
    )
    model.load_state_dict(sd)
    model.to(device).eval()

    # Drop the CPU copy of the weights now that they live in the model.
    del sd
    return model
|
|
|
# Name of the directory that contains this module; passed as the first
# argument to load_file_from_github_release when fetching a checkpoint.
MODEL_TYPE = pathlib.PurePath(__file__).parent.name

# Checkpoint file names offered in the node's "ckpt_name" dropdown.
CKPT_NAMES = [
    "FLAVR_2x.pth",
    "FLAVR_4x.pth",
    "FLAVR_8x.pth",
]
|
|
|
class FLAVR_VFI:
    """ComfyUI node that doubles the frame rate of an IMAGE batch with FLAVR.

    FLAVR looks at a window of ``NBR_FRAME`` consecutive frames and synthesises
    one frame between the window's two middle frames.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs for ComfyUI."""
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 2}),
                "duplicate_first_last_frames": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", ),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        duplicate_first_last_frames: bool = False,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate ``frames`` at 2x with FLAVR.

        A window of four consecutive frames slides over the clip one frame at
        a time. For each window, the predicted frame is inserted between the
        window's two middle frames. The window needs two neighbours on each
        side, so no frame is synthesised between the first two or the last two
        input frames. ``duplicate_first_last_frames`` repeats the first and
        last input frames as stand-ins for those missing frames.

        Args:
            ckpt_name: One of ``CKPT_NAMES``; fetched with
                ``load_file_from_github_release``.
            frames: IMAGE batch. ``assert_batch_size(..., batch_size=4)`` is
                applied to it (presumably requiring at least 4 frames).
            clear_cache_after_n_frames: ``soft_empty_cache()`` runs once the
                counter reaches this value. The counter advances by 2 per
                interpolated window.
            multiplier: Only 2 is supported. Other values emit a warning and
                are otherwise ignored.
            duplicate_first_last_frames: Repeat the first and last input frames
                once each.
            optional_interpolation_states: If given, windows it reports as
                skipped keep their original frames but get no synthesised frame.
            **kwargs: Accepted and ignored.

        Returns:
            A one-element tuple holding the post-processed IMAGE batch. For N
            input frames and no skipped windows it contains 2N-3 frames, or
            2N-1 with ``duplicate_first_last_frames``.
        """
        if multiplier != 2:
            warnings.warn("Currently, FLAVR only supports 2x interpolation. The process will continue but please set multiplier=2 afterward")

        assert_batch_size(frames, batch_size=4, vfi_name="FLAVR")
        interpolation_states = optional_interpolation_states
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = build_flavr(model_path)
        frames = preprocess_frames(frames)
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []
        # Start index of the final window of NBR_FRAME consecutive frames.
        last_window = len(frames) - NBR_FRAME

        for frame_itr in range(last_window + 1):
            # Ensure the input frames are fp32, the same dtype as the model.
            frame0, frame1, frame2, frame3 = (
                frames[frame_itr:frame_itr+1].float(),
                frames[frame_itr+1:frame_itr+2].float(),
                frames[frame_itr+2:frame_itr+3].float(),
                frames[frame_itr+3:frame_itr+4].float(),
            )

            if frame_itr == 0:
                output_frames.append(frame0)
                if duplicate_first_last_frames:
                    output_frames.append(frame0)  # repeat the first frame

            # Each window contributes only its first middle frame. The second
            # middle frame (frame2) is the next window's frame1. Appending it
            # here as well used to emit every interior frame twice.
            output_frames.append(frame1)

            # NOTE(review): the synthesised frame lies between
            # frames[frame_itr+1] and frames[frame_itr+2], but the skip check
            # tests indices frame_itr and frame_itr+1. Confirm the intended
            # InterpolationStateList indexing.
            skipped = (
                interpolation_states is not None
                and interpolation_states.is_frame_skipped(frame_itr)
                and interpolation_states.is_frame_skipped(frame_itr + 1)
            )
            if not skipped:
                # Inference only: no autograd graph / activations kept alive.
                with torch.no_grad():
                    # NOTE(review): build_flavr sizes n_outputs from the
                    # checkpoint, so 4x/8x checkpoints produce several outputs
                    # per window, yet only index 0 is used here.
                    new_frame = model([
                        frame0.to(device),
                        frame1.to(device),
                        frame2.to(device),
                        frame3.to(device),
                    ])[0].detach().cpu()
                output_frames.append(new_frame)
                number_of_frames_processed_since_last_cleared_cuda_cache += 2

            if frame_itr == last_window:
                output_frames.append(frame2)
                output_frames.append(frame3)
                if duplicate_first_last_frames:
                    output_frames.append(frame3)  # repeat the last frame

            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
                gc.collect()

        # Drop the last reference to the model so the final cache clear below
        # can actually release its weights.
        del model

        out = torch.cat(
            [frame.cpu().to(dtype=torch.float32) for frame in output_frames],
            dim=0,
        )
        out = padder.unpad(out)

        print("Comfy-VFI: Final clearing cache...", end=' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
|
|