|
import torch |
|
from comfy.model_management import get_torch_device, soft_empty_cache |
|
import numpy as np |
|
import typing |
|
from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames, assert_batch_size |
|
import pathlib |
|
import warnings |
|
import gc |
|
|
|
# Model-type key passed to the checkpoint downloader: the name of the
# directory that contains this module.
MODEL_TYPE: str = pathlib.Path(__file__).parent.name

# Torch device chosen by ComfyUI's model management; the model and each input
# window are moved here for inference.
device = get_torch_device()
|
|
|
class STMFNet_VFI:
    """ComfyUI node that doubles a frame sequence's frame rate with ST-MFNet.

    ST-MFNet looks at a sliding window of four consecutive frames and
    synthesizes the frame halfway between the 2nd and 3rd of them. As a
    consequence no frame is synthesized between the first two or the last two
    input frames; ``duplicate_first_last_frames`` repeats those edge frames
    instead.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["stmfnet.pth"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 2}),
                "duplicate_first_last_frames": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", ),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        duplicate_first_last_frames: bool = False,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate ``frames`` 2x with ST-MFNet.

        Args:
            ckpt_name: Checkpoint file name, fetched via
                ``load_file_from_github_release``.
            frames: Input IMAGE batch; must hold at least the batch size
                checked by ``assert_batch_size`` (4).
            clear_cache_after_n_frames: Run ``gc.collect()`` and
                ``soft_empty_cache()`` once this many frames have been
                synthesized since the last clear.
            multiplier: Only 2 is supported; any other value emits a warning
                and 2x interpolation is performed anyway.
            duplicate_first_last_frames: Repeat the first and last input frames
                in the output.
            optional_interpolation_states: When given, a window is not
                interpolated if both ``frame_itr`` and ``frame_itr + 1`` are
                reported skipped. The source frames are still emitted.
            **kwargs: Ignored.

        Returns:
            A 1-tuple containing ``postprocess_frames`` of the output batch.
            Output order, along dim 0: f0, [f0], f1, m(1,2), f2, m(2,3), ...,
            f(n-2), f(n-1), [f(n-1)], where m(i,i+1) is a synthesized frame and
            the bracketed entries appear only with
            ``duplicate_first_last_frames``.
        """
        from .stmfnet_arch import STMFNet_Model

        if multiplier != 2:
            warnings.warn("Currently, ST-MFNet only supports 2x interpolation. The process will continue but please set multiplier=2 afterward")

        assert_batch_size(frames, batch_size=4, vfi_name="ST-MFNet")
        interpolation_states = optional_interpolation_states

        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = STMFNet_Model()
        # Load onto the CPU first so checkpoints saved from a CUDA process also
        # load on CPU-only hosts; the model is moved to `device` right after.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model = model.eval().to(device)

        frames = preprocess_frames(frames)
        # Start index of the final 4-frame window; that window also emits frame3.
        last_window_start = len(frames) - 4
        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []

        for frame_itr in range(len(frames) - 3):
            # Ensure that input frames are in fp32 - the dtype the model runs in.
            frame0, frame1, frame2, frame3 = (
                frames[frame_itr + offset:frame_itr + offset + 1].float()
                for offset in range(4)
            )

            skip_interpolation = (
                interpolation_states is not None
                and interpolation_states.is_frame_skipped(frame_itr)
                and interpolation_states.is_frame_skipped(frame_itr + 1)
            )

            # The first window also emits the two leading source frames, which
            # never get a synthesized frame between them.
            if frame_itr == 0:
                output_frames.append(frame0)
                if duplicate_first_last_frames:
                    output_frames.append(frame0)
                output_frames.append(frame1)

            # A skipped window only omits the synthesized frame; the source
            # frames around it must still reach the output.
            if not skip_interpolation:
                # no_grad: inference only, so don't keep activations for autograd.
                with torch.no_grad():
                    new_frame = model(
                        frame0.to(device),
                        frame1.to(device),
                        frame2.to(device),
                        frame3.to(device),
                    ).detach().cpu()
                output_frames.append(new_frame)
                number_of_frames_processed_since_last_cleared_cuda_cache += 2

            output_frames.append(frame2)

            # The last window also emits the trailing source frame.
            if frame_itr == last_window_start:
                output_frames.append(frame3)
                if duplicate_first_last_frames:
                    output_frames.append(frame3)

            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                # Drop unreachable Python references first so the cache clear
                # can actually release their device memory.
                gc.collect()
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")

        out = torch.cat(
            [frame.cpu().to(dtype=torch.float32) for frame in output_frames],
            dim=0,
        )

        # Release the model's weights before the final cache clear.
        del model
        gc.collect()
        print("Comfy-VFI: Final clearing cache...", end = ' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )