Upload 4 files
- dataset.py +109 -0
- inference_img.py +111 -0
- inference_video.py +297 -0
- train.py +155 -0
dataset.py
ADDED
@@ -0,0 +1,109 @@
import os
import cv2
import ast
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset

cv2.setNumThreads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VimeoDataset(Dataset):
    def __init__(self, dataset_name, batch_size=32):
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.h = 256
        self.w = 448
        self.data_root = 'vimeo_triplet'
        self.image_root = os.path.join(self.data_root, 'sequences')
        train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
        with open(train_fn, 'r') as f:
            self.trainlist = f.read().splitlines()
        with open(test_fn, 'r') as f:
            self.testlist = f.read().splitlines()
        self.load_data()

    def __len__(self):
        return len(self.meta_data)

    def load_data(self):
        cnt = int(len(self.trainlist) * 0.95)
        if self.dataset_name == 'train':
            self.meta_data = self.trainlist[:cnt]
        elif self.dataset_name == 'test':
            self.meta_data = self.testlist
        else:
            self.meta_data = self.trainlist[cnt:]

    def crop(self, img0, gt, img1, h, w):
        ih, iw, _ = img0.shape
        x = np.random.randint(0, ih - h + 1)
        y = np.random.randint(0, iw - w + 1)
        img0 = img0[x:x+h, y:y+w, :]
        img1 = img1[x:x+h, y:y+w, :]
        gt = gt[x:x+h, y:y+w, :]
        return img0, gt, img1

    def getimg(self, index):
        imgpath = os.path.join(self.image_root, self.meta_data[index])
        imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png']

        # Load images
        img0 = cv2.imread(imgpaths[0])
        gt = cv2.imread(imgpaths[1])
        img1 = cv2.imread(imgpaths[2])
        timestep = 0.5
        return img0, gt, img1, timestep

# RIFEm with Vimeo-Septuplet
# imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
# ind = [0, 1, 2, 3, 4, 5, 6]
# random.shuffle(ind)
# ind = ind[:3]
# ind.sort()
# img0 = cv2.imread(imgpaths[ind[0]])
# gt = cv2.imread(imgpaths[ind[1]])
# img1 = cv2.imread(imgpaths[ind[2]])
# timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)

    def __getitem__(self, index):
        img0, gt, img1, timestep = self.getimg(index)
        if self.dataset_name == 'train':
            img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
            if random.uniform(0, 1) < 0.5:
                img0 = img0[:, :, ::-1]
                img1 = img1[:, :, ::-1]
                gt = gt[:, :, ::-1]
            if random.uniform(0, 1) < 0.5:
                img0 = img0[::-1]
                img1 = img1[::-1]
                gt = gt[::-1]
            if random.uniform(0, 1) < 0.5:
                img0 = img0[:, ::-1]
                img1 = img1[:, ::-1]
                gt = gt[:, ::-1]
            if random.uniform(0, 1) < 0.5:
                tmp = img1
                img1 = img0
                img0 = tmp
                timestep = 1 - timestep
            # random rotation
            p = random.uniform(0, 1)
            if p < 0.25:
                img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
                gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
                img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
            elif p < 0.5:
                img0 = cv2.rotate(img0, cv2.ROTATE_180)
                gt = cv2.rotate(gt, cv2.ROTATE_180)
                img1 = cv2.rotate(img1, cv2.ROTATE_180)
            elif p < 0.75:
                img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
                gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
                img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
        timestep = torch.tensor(timestep).reshape(1, 1, 1)
        return torch.cat((img0, img1, gt), 0), timestep
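A minimal sketch of how this dataset is typically consumed, assuming the vimeo_triplet directory sits next to the script; the batch size, worker count, and loop below are illustrative and not part of the commit:

from torch.utils.data import DataLoader
from dataset import VimeoDataset

# Illustrative values: batch size and num_workers are assumptions.
train_set = VimeoDataset('train')
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2, drop_last=True)

for batch, timestep in train_loader:
    # batch: uint8 tensor of shape (B, 9, 224, 224) -- img0, img1, gt stacked along the channel axis
    # timestep: tensor of shape (B, 1, 1, 1), fixed at 0.5 for the triplet dataset
    imgs, gt = batch[:, :6], batch[:, 6:9]
    break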
inference_img.py
ADDED
@@ -0,0 +1,111 @@
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--img', dest='img', nargs=2, required=True)
parser.add_argument('--exp', default=4, type=int)
parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')

args = parser.parse_args()

try:
    try:
        try:
            from model.RIFE_HDv2 import Model
            model = Model()
            model.load_model(args.modelDir, -1)
            print("Loaded v2.x HD model.")
        except:
            from train_log.RIFE_HDv3 import Model
            model = Model()
            model.load_model(args.modelDir, -1)
            print("Loaded v3.x HD model.")
    except:
        from model.RIFE_HD import Model
        model = Model()
        model.load_model(args.modelDir, -1)
        print("Loaded v1.x HD model")
except:
    from model.RIFE import Model
    model = Model()
    model.load_model(args.modelDir, -1)
    print("Loaded ArXiv-RIFE model")
model.eval()
model.device()

if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
    img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
    img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)

else:
    img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
    img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)

n, c, h, w = img0.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
img0 = F.pad(img0, padding)
img1 = F.pad(img1, padding)


if args.ratio:
    img_list = [img0]
    img0_ratio = 0.0
    img1_ratio = 1.0
    if args.ratio <= img0_ratio + args.rthreshold / 2:
        middle = img0
    elif args.ratio >= img1_ratio - args.rthreshold / 2:
        middle = img1
    else:
        tmp_img0 = img0
        tmp_img1 = img1
        for inference_cycle in range(args.rmaxcycles):
            middle = model.inference(tmp_img0, tmp_img1)
            middle_ratio = ( img0_ratio + img1_ratio ) / 2
            if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
                break
            if args.ratio > middle_ratio:
                tmp_img0 = middle
                img0_ratio = middle_ratio
            else:
                tmp_img1 = middle
                img1_ratio = middle_ratio
    img_list.append(middle)
    img_list.append(img1)
else:
    img_list = [img0, img1]
    for i in range(args.exp):
        tmp = []
        for j in range(len(img_list) - 1):
            mid = model.inference(img_list[j], img_list[j + 1])
            tmp.append(img_list[j])
            tmp.append(mid)
        tmp.append(img1)
        img_list = tmp

if not os.path.exists('output'):
    os.mkdir('output')
for i in range(len(img_list)):
    if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
        cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
    else:
        cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
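Based on the flags defined above, a typical invocation (with placeholder file names) would be:

    python3 inference_img.py --img frame0.png frame1.png --exp=2

which writes the two inputs plus the interpolated frames (2^exp + 1 images in total, or a single ratio-matched frame when --ratio is given) to an output/ directory next to the script.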
inference_video.py
ADDED
@@ -0,0 +1,297 @@
import os
import cv2
import torch
import argparse
import numpy as np
from tqdm import tqdm
from torch.nn import functional as F
import warnings
import _thread
import skvideo.io
from queue import Queue, Empty
from model.pytorch_msssim import ssim_matlab

warnings.filterwarnings("ignore")

def transferAudio(sourceVideo, targetVideo):
    import shutil
    import moviepy.editor
    tempAudioFileName = "./temp/audio.mkv"

    # split audio from original video file and store in "temp" directory
    if True:

        # clear old "temp" directory if it exits
        if os.path.isdir("temp"):
            # remove temp directory
            shutil.rmtree("temp")
        # create new "temp" directory
        os.makedirs("temp")
        # extract audio from video
        os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))

    targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
    os.rename(targetVideo, targetNoAudio)
    # combine audio file and new video file
    os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))

    if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
        tempAudioFileName = "./temp/audio.m4a"
        os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
        os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
        if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
            os.rename(targetNoAudio, targetVideo)
            print("Audio transfer failed. Interpolated video will have no audio")
        else:
            print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")

            # remove audio-less video
            os.remove(targetNoAudio)
    else:
        os.remove(targetNoAudio)

    # remove temp directory
    shutil.rmtree("temp")

parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', type=str, default=None)
parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
parser.add_argument('--fps', dest='fps', type=int, default=None)
parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
parser.add_argument('--exp', dest='exp', type=int, default=1)
args = parser.parse_args()
assert (not args.video is None or not args.img is None)
if args.skip:
    print("skip flag is abandoned, please refer to issue #207.")
if args.UHD and args.scale==1.0:
    args.scale = 0.5
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
if not args.img is None:
    args.png = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    if(args.fp16):
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

try:
    try:
        try:
            from model.RIFE_HDv2 import Model
            model = Model()
            model.load_model(args.modelDir, -1)
            print("Loaded v2.x HD model.")
        except:
            from train_log.RIFE_HDv3 import Model
            model = Model()
            model.load_model(args.modelDir, -1)
            print("Loaded v3.x HD model.")
    except:
        from model.RIFE_HD import Model
        model = Model()
        model.load_model(args.modelDir, -1)
        print("Loaded v1.x HD model")
except:
    from model.RIFE import Model
    model = Model()
    model.load_model(args.modelDir, -1)
    print("Loaded ArXiv-RIFE model")
model.eval()
model.device()

if not args.video is None:
    videoCapture = cv2.VideoCapture(args.video)
    fps = videoCapture.get(cv2.CAP_PROP_FPS)
    tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
    videoCapture.release()
    if args.fps is None:
        fpsNotAssigned = True
        args.fps = fps * (2 ** args.exp)
    else:
        fpsNotAssigned = False
    videogen = skvideo.io.vreader(args.video)
    lastframe = next(videogen)
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    video_path_wo_ext, ext = os.path.splitext(args.video)
    print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
    if args.png == False and fpsNotAssigned == True:
        print("The audio will be merged after interpolation process")
    else:
        print("Will not merge audio because using png or fps flag!")
else:
    videogen = []
    for f in os.listdir(args.img):
        if 'png' in f:
            videogen.append(f)
    tot_frame = len(videogen)
    videogen.sort(key= lambda x:int(x[:-4]))
    lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
    videogen = videogen[1:]
h, w, _ = lastframe.shape
vid_out_name = None
vid_out = None
if args.png:
    if not os.path.exists('vid_out'):
        os.mkdir('vid_out')
else:
    if args.output is not None:
        vid_out_name = args.output
    else:
        vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)
    vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))

def clear_write_buffer(user_args, write_buffer):
    cnt = 0
    while True:
        item = write_buffer.get()
        if item is None:
            break
        if user_args.png:
            cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
            cnt += 1
        else:
            vid_out.write(item[:, :, ::-1])

def build_read_buffer(user_args, read_buffer, videogen):
    try:
        for frame in videogen:
            if not user_args.img is None:
                frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
            if user_args.montage:
                frame = frame[:, left: left + w]
            read_buffer.put(frame)
    except:
        pass
    read_buffer.put(None)

def make_inference(I0, I1, n):
    global model
    middle = model.inference(I0, I1, args.scale)
    if n == 1:
        return [middle]
    first_half = make_inference(I0, middle, n=n//2)
    second_half = make_inference(middle, I1, n=n//2)
    if n%2:
        return [*first_half, middle, *second_half]
    else:
        return [*first_half, *second_half]

def pad_image(img):
    if(args.fp16):
        return F.pad(img, padding).half()
    else:
        return F.pad(img, padding)

if args.montage:
    left = w // 4
    w = w // 2
tmp = max(32, int(32 / args.scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
pbar = tqdm(total=tot_frame)
if args.montage:
    lastframe = lastframe[:, left: left + w]
write_buffer = Queue(maxsize=500)
read_buffer = Queue(maxsize=500)
_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
_thread.start_new_thread(clear_write_buffer, (args, write_buffer))

I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
I1 = pad_image(I1)
temp = None # save lastframe when processing static frame

while True:
    if temp is not None:
        frame = temp
        temp = None
    else:
        frame = read_buffer.get()
    if frame is None:
        break
    I0 = I1
    I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
    I1 = pad_image(I1)
    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

    break_flag = False
    if ssim > 0.996:
        frame = read_buffer.get() # read a new frame
        if frame is None:
            break_flag = True
            frame = lastframe
        else:
            temp = frame
        I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
        I1 = pad_image(I1)
        I1 = model.inference(I0, I1, args.scale)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
        frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]

    if ssim < 0.2:
        output = []
        for i in range((2 ** args.exp) - 1):
            output.append(I0)
        '''
        output = []
        step = 1 / (2 ** args.exp)
        alpha = 0
        for i in range((2 ** args.exp) - 1):
            alpha += step
            beta = 1-alpha
            output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
        '''
    else:
        output = make_inference(I0, I1, 2**args.exp-1) if args.exp else []

    if args.montage:
        write_buffer.put(np.concatenate((lastframe, lastframe), 1))
        for mid in output:
            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
            write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
    else:
        write_buffer.put(lastframe)
        for mid in output:
            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
            write_buffer.put(mid[:h, :w])
    pbar.update(1)
    lastframe = frame
    if break_flag:
        break

if args.montage:
    write_buffer.put(np.concatenate((lastframe, lastframe), 1))
else:
    write_buffer.put(lastframe)

write_buffer.put(None)

import time
while(not write_buffer.empty()):
    time.sleep(0.1)
pbar.close()
if not vid_out is None:
    vid_out.release()

# move audio to new video file if appropriate
if args.png == False and fpsNotAssigned == True and not args.video is None:
    try:
        transferAudio(args.video, vid_out_name)
    except:
        print("Audio transfer failed. Interpolated video will have no audio")
        targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
        os.rename(targetNoAudio, vid_out_name)
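Given the parser above, typical invocations (with placeholder input names) would be:

    python3 inference_video.py --exp=1 --video=input.mp4
    python3 inference_video.py --exp=2 --img=frames/ --png

The first doubles the frame rate of a video and, since neither --fps nor --png is set, re-attaches the original audio track via transferAudio; the second reads numbered PNG frames from a directory and writes interpolated PNGs to vid_out/.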
train.py
ADDED
@@ -0,0 +1,155 @@
import os
import cv2
import math
import time
import torch
import torch.distributed as dist
import numpy as np
import random
import argparse

from model.RIFE import Model
from dataset import *
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler

device = torch.device("cuda")

log_path = 'train_log'

def get_learning_rate(step):
    if step < 2000:
        mul = step / 2000.
        return 3e-4 * mul
    else:
        mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
        return (3e-4 - 3e-6) * mul + 3e-6

def flow2rgb(flow_map_np):
    h, w, _ = flow_map_np.shape
    rgb_map = np.ones((h, w, 3)).astype(np.float32)
    normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())

    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
    rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
    return rgb_map.clip(0, 1)

def train(model, local_rank):
    if local_rank == 0:
        writer = SummaryWriter('train')
        writer_val = SummaryWriter('validate')
    else:
        writer = None
        writer_val = None
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, timestep = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            timestep = timestep.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            learning_rate = get_learning_rate(step) * args.world_size / 4
            pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 200 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss/l1', info['loss_l1'], step)
                writer.add_scalar('loss/tea', info['loss_tea'], step)
                writer.add_scalar('loss/distill', info['loss_distill'], step)
            if step % 1000 == 1 and local_rank == 0:
                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                mask = (torch.cat((info['mask'], info['mask_tea']), 3).permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
                flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()
                for i in range(5):
                    imgs = np.concatenate((merged_img[i], pred[i], gt[i]), 1)[:, :, ::-1]
                    writer.add_image(str(i) + '/img', imgs, step, dataformats='HWC')
                    writer.add_image(str(i) + '/flow', np.concatenate((flow2rgb(flow0[i]), flow2rgb(flow1[i])), 1), step, dataformats='HWC')
                    writer.add_image(str(i) + '/mask', mask[i], step, dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, info['loss_l1']))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)
        dist.barrier()

def evaluate(model, val_data, nr_eval, local_rank, writer_val):
    loss_l1_list = []
    loss_distill_list = []
    loss_tea_list = []
    psnr_list = []
    psnr_list_teacher = []
    time_stamp = time.time()
    for i, data in enumerate(val_data):
        data_gpu, timestep = data
        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, info = model.update(imgs, gt, training=False)
            merged_img = info['merged_tea']
        loss_l1_list.append(info['loss_l1'].cpu().numpy())
        loss_tea_list.append(info['loss_tea'].cpu().numpy())
        loss_distill_list.append(info['loss_distill'].cpu().numpy())
        for j in range(gt.shape[0]):
            psnr = -10 * math.log10(torch.mean((gt[j] - pred[j]) * (gt[j] - pred[j])).cpu().data)
            psnr_list.append(psnr)
            psnr = -10 * math.log10(torch.mean((merged_img[j] - gt[j]) * (merged_img[j] - gt[j])).cpu().data)
            psnr_list_teacher.append(psnr)
        gt = (gt.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        pred = (pred.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        merged_img = (merged_img.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        flow0 = info['flow'].permute(0, 2, 3, 1).cpu().numpy()
        flow1 = info['flow_tea'].permute(0, 2, 3, 1).cpu().numpy()
        if i == 0 and local_rank == 0:
            for j in range(10):
                imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
                writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
                writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')

    eval_time_interval = time.time() - time_stamp

    if local_rank != 0:
        return
    writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
    writer_val.add_scalar('psnr_teacher', np.array(psnr_list_teacher).mean(), nr_eval)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', default=300, type=int)
    parser.add_argument('--batch_size', default=16, type=int, help='minibatch size')
    parser.add_argument('--local_rank', default=0, type=int, help='local rank')
    parser.add_argument('--world_size', default=4, type=int, help='world size')
    args = parser.parse_args()
    torch.distributed.init_process_group(backend="nccl", world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    seed = 1234
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    model = Model(args.local_rank)
    train(model, args.local_rank)
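train.py expects a distributed launch: it reads --local_rank, initializes an NCCL process group sized by --world_size, and shards the Vimeo dataset with a DistributedSampler. A plausible single-node launch on 4 GPUs (the launcher command is an assumption, not part of this commit) would be:

    python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4

Checkpoints are written to train_log/ and TensorBoard logs to the train/ and validate/ directories.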