File size: 11,578 Bytes
e61bb9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# thanks to https://github.com/n00mkrad for the inspiration and a bit of code. Also thanks for https://github.com/XmYx for the initial reorganization of this script
import os
from types import SimpleNamespace
import cv2
import torch
import shutil
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
import warnings
import _thread
from queue import Queue
import time
from .model.pytorch_msssim import ssim_matlab

from deforum_helpers.video_audio_utilities import ffmpeg_stitch_video
from deforum_helpers.general_utils import duplicate_pngs_from_folder

warnings.filterwarnings("ignore")

def run_rife_new_video_infer(
        output=None,
        model=None,
        fp16=False,
        UHD=False, # *Will be received as *True* if imgs/vid resolution is 2K or higher*
        scale=1.0,
        fps=None,
        deforum_models_path=None,
        raw_output_imgs_path=None,
        img_batch_id=None,
        ffmpeg_location=None,
        audio_track=None,
        interp_x_amount=2,
        slow_mo_enabled=False,
        slow_mo_x_amount=2,
        ffmpeg_crf=17,
        ffmpeg_preset='veryslow',
        keep_imgs=False,
        orig_vid_name = None,
        srt_path = None):
    """
    Interpolate a folder of PNG frames with a RIFE model and stitch the result into an mp4.

    Frames are streamed through two bounded queues: a reader thread decodes source PNGs
    into ``read_buffer`` while a writer thread drains ``write_buffer`` to disk, and the
    main thread runs the model in between. Returns the stitched video path on success,
    or None when no model was provided or stitching failed (interpolated frames are
    still left on disk as a backup in the failure case).
    """
    # Bundle everything into a namespace so the helper threads get one opaque arg.
    args = SimpleNamespace()
    args.output = output
    args.modelDir = model
    args.fp16 = fp16
    args.UHD = UHD
    args.scale = scale
    args.fps = fps
    args.deforum_models_path = deforum_models_path
    args.raw_output_imgs_path = raw_output_imgs_path
    args.img_batch_id = img_batch_id
    args.ffmpeg_location = ffmpeg_location
    args.audio_track = audio_track
    args.interp_x_amount = interp_x_amount
    args.slow_mo_enabled = slow_mo_enabled
    args.slow_mo_x_amount = slow_mo_x_amount
    args.ffmpeg_crf = ffmpeg_crf
    args.ffmpeg_preset = ffmpeg_preset
    args.keep_imgs = keep_imgs
    args.orig_vid_name = orig_vid_name

    # For 2K+ content the flow network is run at half scale to keep memory in check.
    if args.UHD and args.scale == 1.0:
        args.scale = 0.5

    start_time = time.time()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        # TODO: Can/ need to handle this? currently it's always False and give errors if True but faster speeds on tensortcore equipped gpus?
        if (args.fp16):
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
    if args.modelDir is not None:
        try:
            from .rife_new_gen.RIFE_HDv3 import Model
        except ImportError as e:
            raise ValueError(f"{args.modelDir} could not be found. Please contact deforum support {e}")
        except Exception as e:
            raise ValueError(f"An error occured while trying to import {args.modelDir}: {e}")
    else:
        print("Got a request to frame-interpolate but no valid frame interpolation engine value provided. Doing... nothing")
        return

    model = Model()
    # Older checkpoints don't carry a version attribute; 0 routes them to the
    # midpoint-recursion path in make_inference.
    if not hasattr(model, 'version'):
        model.version = 0
    model.load_model(args.modelDir, -1, deforum_models_path)
    model.eval()
    model.device()

    print(f"{args.modelDir}.pkl model successfully loaded into memory")
    print("Interpolation progress (it's OK if it finishes before 100%):")

    interpolated_path = os.path.join(args.raw_output_imgs_path, 'interpolated_frames_rife')
    # set custom name depending on if we interpolate after a run, or interpolate a video (related/unrelated to deforum, we don't know) directly from within the RIFE tab
    if args.orig_vid_name is not None: # interpolating a video (deforum or unrelated)
        custom_interp_path = "{}_{}".format(interpolated_path, args.orig_vid_name)
    else: # interpolating after a deforum run:
        custom_interp_path = "{}_{}".format(interpolated_path, args.img_batch_id)

    # In this folder we temporarily keep the original frames (converted/ copy-pasted and img format depends on scenario)
    # the convertion case is done to avert a problem with 24 and 32 mixed outputs from the same animation run
    temp_convert_raw_png_path = os.path.join(args.raw_output_imgs_path, "tmp_rife_folder")

    duplicate_pngs_from_folder(args.raw_output_imgs_path, temp_convert_raw_png_path, args.img_batch_id, args.orig_vid_name)

    videogen = []
    for f in os.listdir(temp_convert_raw_png_path):
        # double check for old _depth_ files, not really needed probably but keeping it for now
        if '_depth_' not in f:
            videogen.append(f)
    tot_frame = len(videogen)
    # Frame files are named by integer index; sort numerically, not lexically.
    videogen.sort(key= lambda x:int(x.split('.')[0]))
    img_path = os.path.join(temp_convert_raw_png_path, videogen[0])
    # imdecode+fromfile survives non-ASCII paths; [:, :, ::-1] flips BGR -> RGB.
    lastframe = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
    videogen = videogen[1:]
    h, w, _ = lastframe.shape
    vid_out = None

    if not os.path.exists(custom_interp_path):
        os.mkdir(custom_interp_path)

    # Pad H/W up to a multiple the flow network requires (grows as scale shrinks).
    tmp = max(128, int(128 / args.scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)
    pbar = tqdm(total=tot_frame)

    write_buffer = Queue(maxsize=500)
    read_buffer  = Queue(maxsize=500)

    _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen, temp_convert_raw_png_path))
    _thread.start_new_thread(clear_write_buffer, (args, write_buffer, custom_interp_path))

    I1 = torch.from_numpy(np.transpose(lastframe, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
    I1 = pad_image(I1, args.fp16, padding)
    temp = None  # save lastframe when processing static frame

    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            frame = read_buffer.get()
        if frame is None:
            break
        I0 = I1
        I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
        I1 = pad_image(I1, args.fp16, padding)
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # Near-identical consecutive frames (static shot): synthesize a midpoint
        # against the frame after next instead of interpolating a duplicate.
        if ssim > 0.996:
            frame = read_buffer.get()  # read a new frame
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
            I1 = pad_image(I1, args.fp16, padding)
            I1 = model.inference(I0, I1, args.scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]

        if ssim < 0.2:
            # Scene cut: don't blend across it, just repeat the previous frame.
            inferred = []
            for i in range(args.interp_x_amount - 1):
                inferred.append(I0)
        else:
            # Fix: use args.scale (may have been halved for UHD above), not the raw
            # `scale` parameter, matching the model.inference call in the ssim branch.
            inferred = make_inference(model, I0, I1, args.interp_x_amount - 1, args.scale)

        write_buffer.put(lastframe)
        for mid in inferred:
            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
            write_buffer.put(mid[:h, :w])
        pbar.update(1)
        lastframe = frame
        if break_flag:
            break

    write_buffer.put(lastframe)
    # Fix: send the sentinel so clear_write_buffer terminates. It also guarantees the
    # final frame's imwrite finished before the queue reads as empty below (the writer
    # only dequeues the sentinel after saving the last frame).
    write_buffer.put(None)

    while (not write_buffer.empty()):
        time.sleep(0.1)
    pbar.close()
    shutil.rmtree(temp_convert_raw_png_path)

    print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!")
    # stitch video from interpolated frames, and add audio if needed
    try:
        print (f"*Passing interpolated frames to ffmpeg...*")
        vid_out_path = stitch_video(args.img_batch_id, args.fps, custom_interp_path, args.audio_track, args.ffmpeg_location, args.interp_x_amount, args.slow_mo_enabled, args.slow_mo_x_amount, args.ffmpeg_crf, args.ffmpeg_preset, args.keep_imgs, args.orig_vid_name, srt_path=srt_path)
        # remove folder with raw (non-interpolated) vid input frames in case of input VID and not PNGs
        if orig_vid_name is not None:
            shutil.rmtree(raw_output_imgs_path)
        return vid_out_path
    except Exception as e:
        print(f'Video stitching gone wrong. *Interpolated frames were saved to HD as backup!*. Actual error: {e}')
    
def clear_write_buffer(user_args, write_buffer, custom_interp_path):
    # Writer-thread loop: drain the queue and save each frame as a
    # zero-padded PNG under custom_interp_path until a None sentinel arrives.
    frame_idx = 0
    while True:
        frame = write_buffer.get()
        if frame is None:
            break
        out_file = f"{custom_interp_path}/{frame_idx:09d}.png"
        # Queued frames are RGB; cv2 writes BGR, so reverse the channel axis.
        cv2.imwrite(out_file, frame[:, :, ::-1])
        frame_idx += 1

def build_read_buffer(user_args, read_buffer, videogen, temp_convert_raw_png_path):
    # Reader-thread loop: push each frame onto the queue, decoding from disk when a
    # source folder is given, then signal end-of-stream with a None sentinel.
    for entry in videogen:
        if temp_convert_raw_png_path is not None:
            full_path = os.path.join(temp_convert_raw_png_path, entry)
            # fromfile+imdecode tolerates non-ASCII paths; [:, :, ::-1] flips BGR -> RGB.
            entry = cv2.imdecode(np.fromfile(full_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
        read_buffer.put(entry)
    read_buffer.put(None)

def make_inference(model, I0, I1, n, scale):
    """Return n interpolated frames between I0 and I1, in temporal order."""
    if model.version >= 3.9:
        # Arbitrary-timestep models: query each intermediate timestep directly.
        return [model.inference(I0, I1, (i + 1) / (n + 1), scale) for i in range(n)]
    # Midpoint-only models: bisect recursively until n frames are produced.
    middle = model.inference(I0, I1, scale)
    if n == 1:
        return [middle]
    first_half = make_inference(model, I0, middle, n=n // 2, scale=scale)
    second_half = make_inference(model, middle, I1, n=n // 2, scale=scale)
    # Odd n keeps the midpoint itself; even n splits evenly around it.
    if n % 2:
        return [*first_half, middle, *second_half]
    return [*first_half, *second_half]

def pad_image(img, fp16, padding):
    # Zero-pad the tensor to the model's required size multiple; cast to
    # half precision when fp16 mode is enabled.
    padded = F.pad(img, padding)
    return padded.half() if fp16 else padded

# TODO: move to frame_interpolation and add FILM to it!
def stitch_video(img_batch_id, fps, img_folder_path, audio_path, ffmpeg_location, interp_x_amount, slow_mo_enabled, slow_mo_x_amount, f_crf, f_preset, keep_imgs, orig_vid_name, srt_path=None):
    """Stitch the interpolated PNG sequence into an mp4 and return its path."""
    parent_dir = os.path.dirname(img_folder_path)
    grandparent_dir = os.path.dirname(parent_dir)
    # Video-input runs place the mp4 next to the source video (grandparent folder);
    # deforum-run interpolations place it next to the frames (parent folder).
    if orig_vid_name is not None:
        out_base = os.path.join(grandparent_dir, f"{orig_vid_name}_RIFE_x{interp_x_amount}")
    else:
        out_base = os.path.join(parent_dir, f"{img_batch_id}_RIFE_x{interp_x_amount}")

    if slow_mo_enabled:
        out_base = out_base + f"_slomo_x{slow_mo_x_amount}"
    mp4_path = out_base + '.mp4'

    frame_pattern = os.path.join(img_folder_path, "%09d.png")
    add_soundtrack = 'None' if audio_path is None else 'File'

    stitch_failed = False
    try:
        ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=1000000, imgs_path=frame_pattern, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset, srt_path=srt_path)
    except Exception as e:
        stitch_failed = True
        print(f"An error occurred while stitching the video: {e}")

    # Frames are disposable once stitching succeeded, unless the user wants them kept.
    if not stitch_failed and not keep_imgs:
        shutil.rmtree(img_folder_path)

    # For video-input runs, surface the frame folder next to the mp4 when kept or on failure.
    if (keep_imgs and orig_vid_name is not None) or (orig_vid_name is not None and stitch_failed is True):
        shutil.move(img_folder_path, grandparent_dir)

    return mp4_path