AItool committed
Commit ffad2e7 · verified · 1 parent: 801e00f

Upload 4 files

Files changed (4)
  1. dataset.py +109 -0
  2. inference_img.py +111 -0
  3. inference_video.py +297 -0
  4. train.py +155 -0
dataset.py ADDED
@@ -0,0 +1,109 @@
+import os
+import cv2
+import ast
+import torch
+import numpy as np
+import random
+from torch.utils.data import DataLoader, Dataset
+
+cv2.setNumThreads(1)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+class VimeoDataset(Dataset):
+    def __init__(self, dataset_name, batch_size=32):
+        self.batch_size = batch_size
+        self.dataset_name = dataset_name
+        self.h = 256
+        self.w = 448
+        self.data_root = 'vimeo_triplet'
+        self.image_root = os.path.join(self.data_root, 'sequences')
+        train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
+        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
+        with open(train_fn, 'r') as f:
+            self.trainlist = f.read().splitlines()
+        with open(test_fn, 'r') as f:
+            self.testlist = f.read().splitlines()
+        self.load_data()
+
+    def __len__(self):
+        return len(self.meta_data)
+
+    def load_data(self):
+        cnt = int(len(self.trainlist) * 0.95)
+        if self.dataset_name == 'train':
+            self.meta_data = self.trainlist[:cnt]
+        elif self.dataset_name == 'test':
+            self.meta_data = self.testlist
+        else:
+            self.meta_data = self.trainlist[cnt:]
+
+    def crop(self, img0, gt, img1, h, w):
+        ih, iw, _ = img0.shape
+        x = np.random.randint(0, ih - h + 1)
+        y = np.random.randint(0, iw - w + 1)
+        img0 = img0[x:x+h, y:y+w, :]
+        img1 = img1[x:x+h, y:y+w, :]
+        gt = gt[x:x+h, y:y+w, :]
+        return img0, gt, img1
+
+    def getimg(self, index):
+        imgpath = os.path.join(self.image_root, self.meta_data[index])
+        imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png']
+
+        # Load images
+        img0 = cv2.imread(imgpaths[0])
+        gt = cv2.imread(imgpaths[1])
+        img1 = cv2.imread(imgpaths[2])
+        timestep = 0.5
+        return img0, gt, img1, timestep
+
+    # RIFEm with Vimeo-Septuplet
+    # imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
+    # ind = [0, 1, 2, 3, 4, 5, 6]
+    # random.shuffle(ind)
+    # ind = ind[:3]
+    # ind.sort()
+    # img0 = cv2.imread(imgpaths[ind[0]])
+    # gt = cv2.imread(imgpaths[ind[1]])
+    # img1 = cv2.imread(imgpaths[ind[2]])
+    # timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)
+
+    def __getitem__(self, index):
+        img0, gt, img1, timestep = self.getimg(index)
+        if self.dataset_name == 'train':
+            img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
+            if random.uniform(0, 1) < 0.5:
+                img0 = img0[:, :, ::-1]
+                img1 = img1[:, :, ::-1]
+                gt = gt[:, :, ::-1]
+            if random.uniform(0, 1) < 0.5:
+                img0 = img0[::-1]
+                img1 = img1[::-1]
+                gt = gt[::-1]
+            if random.uniform(0, 1) < 0.5:
+                img0 = img0[:, ::-1]
+                img1 = img1[:, ::-1]
+                gt = gt[:, ::-1]
+            if random.uniform(0, 1) < 0.5:
+                tmp = img1
+                img1 = img0
+                img0 = tmp
+                timestep = 1 - timestep
+            # random rotation
+            p = random.uniform(0, 1)
+            if p < 0.25:
+                img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
+                gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
+                img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
+            elif p < 0.5:
+                img0 = cv2.rotate(img0, cv2.ROTATE_180)
+                gt = cv2.rotate(gt, cv2.ROTATE_180)
+                img1 = cv2.rotate(img1, cv2.ROTATE_180)
+            elif p < 0.75:
+                img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
+        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
+        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
+        timestep = torch.tensor(timestep).reshape(1, 1, 1)
+        return torch.cat((img0, img1, gt), 0), timestep
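
A minimal smoke test for the loader, not part of the commit: it assumes the Vimeo90K triplet archive has been unpacked to ./vimeo_triplet next to dataset.py and just pulls one training batch to confirm the tensor layout.

    # Hypothetical usage sketch; assumes ./vimeo_triplet with tri_trainlist.txt exists.
    from torch.utils.data import DataLoader
    from dataset import VimeoDataset

    dataset = VimeoDataset('train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    imgs, timestep = next(iter(loader))
    print(imgs.shape)      # torch.Size([4, 9, 224, 224]): img0, img1, gt stacked on channels, uint8
    print(timestep.shape)  # torch.Size([4, 1, 1, 1]): always 0.5 for triplets

Note the images are returned as uint8; train.py divides by 255 on the GPU.
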
inference_img.py ADDED
@@ -0,0 +1,111 @@
+import os
+import cv2
+import torch
+import argparse
+from torch.nn import functional as F
+import warnings
+warnings.filterwarnings("ignore")
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch.set_grad_enabled(False)
+if torch.cuda.is_available():
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+
+parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+parser.add_argument('--img', dest='img', nargs=2, required=True)
+parser.add_argument('--exp', default=4, type=int)
+parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
+parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
+parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
+parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+
+args = parser.parse_args()
+
+try:
+    try:
+        try:
+            from model.RIFE_HDv2 import Model
+            model = Model()
+            model.load_model(args.modelDir, -1)
+            print("Loaded v2.x HD model.")
+        except:
+            from train_log.RIFE_HDv3 import Model
+            model = Model()
+            model.load_model(args.modelDir, -1)
+            print("Loaded v3.x HD model.")
+    except:
+        from model.RIFE_HD import Model
+        model = Model()
+        model.load_model(args.modelDir, -1)
+        print("Loaded v1.x HD model")
+except:
+    from model.RIFE import Model
+    model = Model()
+    model.load_model(args.modelDir, -1)
+    print("Loaded ArXiv-RIFE model")
+model.eval()
+model.device()
+
+if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+    img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+    img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
+    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
+
+else:
+    img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
+    img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
+    img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+    img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
+
+n, c, h, w = img0.shape
+ph = ((h - 1) // 32 + 1) * 32
+pw = ((w - 1) // 32 + 1) * 32
+padding = (0, pw - w, 0, ph - h)
+img0 = F.pad(img0, padding)
+img1 = F.pad(img1, padding)
+
+
+if args.ratio:
+    img_list = [img0]
+    img0_ratio = 0.0
+    img1_ratio = 1.0
+    if args.ratio <= img0_ratio + args.rthreshold / 2:
+        middle = img0
+    elif args.ratio >= img1_ratio - args.rthreshold / 2:
+        middle = img1
+    else:
+        tmp_img0 = img0
+        tmp_img1 = img1
+        for inference_cycle in range(args.rmaxcycles):
+            middle = model.inference(tmp_img0, tmp_img1)
+            middle_ratio = (img0_ratio + img1_ratio) / 2
+            if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
+                break
+            if args.ratio > middle_ratio:
+                tmp_img0 = middle
+                img0_ratio = middle_ratio
+            else:
+                tmp_img1 = middle
+                img1_ratio = middle_ratio
+    img_list.append(middle)
+    img_list.append(img1)
+else:
+    img_list = [img0, img1]
+    for i in range(args.exp):
+        tmp = []
+        for j in range(len(img_list) - 1):
+            mid = model.inference(img_list[j], img_list[j + 1])
+            tmp.append(img_list[j])
+            tmp.append(mid)
+        tmp.append(img1)
+        img_list = tmp
+
+if not os.path.exists('output'):
+    os.mkdir('output')
+for i in range(len(img_list)):
+    if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
+        cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+    else:
+        cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
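
Typical invocations, with placeholder image names that are not part of the commit (both expect a trained model under --model, default train_log/):

    python3 inference_img.py --img frame0.png frame1.png --exp=4
    python3 inference_img.py --img frame0.png frame1.png --ratio=0.4

The first run should write img0, the 15 recursively interpolated frames and img1 as output/img0.png through output/img16.png; the second bisects with model.inference until the estimated ratio falls within --rthreshold of the request (or --rmaxcycles is reached) and writes three images.
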
inference_video.py ADDED
@@ -0,0 +1,297 @@
+import os
+import cv2
+import torch
+import argparse
+import numpy as np
+from tqdm import tqdm
+from torch.nn import functional as F
+import warnings
+import _thread
+import skvideo.io
+from queue import Queue, Empty
+from model.pytorch_msssim import ssim_matlab
+
+warnings.filterwarnings("ignore")
+
+def transferAudio(sourceVideo, targetVideo):
+    import shutil
+    import moviepy.editor
+    tempAudioFileName = "./temp/audio.mkv"
+
+    # split audio from original video file and store in "temp" directory
+    if True:
+
+        # clear old "temp" directory if it exits
+        if os.path.isdir("temp"):
+            # remove temp directory
+            shutil.rmtree("temp")
+        # create new "temp" directory
+        os.makedirs("temp")
+        # extract audio from video
+        os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
+
+    targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
+    os.rename(targetVideo, targetNoAudio)
+    # combine audio file and new video file
+    os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+
+    if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
+        tempAudioFileName = "./temp/audio.m4a"
+        os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
+        os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+        if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
+            os.rename(targetNoAudio, targetVideo)
+            print("Audio transfer failed. Interpolated video will have no audio")
+        else:
+            print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
+
+            # remove audio-less video
+            os.remove(targetNoAudio)
+    else:
+        os.remove(targetNoAudio)
+
+    # remove temp directory
+    shutil.rmtree("temp")
+
+parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+parser.add_argument('--video', dest='video', type=str, default=None)
+parser.add_argument('--output', dest='output', type=str, default=None)
+parser.add_argument('--img', dest='img', type=str, default=None)
+parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
+parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
+parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
+parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
+parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
+parser.add_argument('--fps', dest='fps', type=int, default=None)
+parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
+parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
+parser.add_argument('--exp', dest='exp', type=int, default=1)
+args = parser.parse_args()
+assert (not args.video is None or not args.img is None)
+if args.skip:
+    print("skip flag is abandoned, please refer to issue #207.")
+if args.UHD and args.scale==1.0:
+    args.scale = 0.5
+assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+if not args.img is None:
+    args.png = True
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch.set_grad_enabled(False)
+if torch.cuda.is_available():
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+    if(args.fp16):
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+try:
+    try:
+        try:
+            from model.RIFE_HDv2 import Model
+            model = Model()
+            model.load_model(args.modelDir, -1)
+            print("Loaded v2.x HD model.")
+        except:
+            from train_log.RIFE_HDv3 import Model
+            model = Model()
+            model.load_model(args.modelDir, -1)
+            print("Loaded v3.x HD model.")
+    except:
+        from model.RIFE_HD import Model
+        model = Model()
+        model.load_model(args.modelDir, -1)
+        print("Loaded v1.x HD model")
+except:
+    from model.RIFE import Model
+    model = Model()
+    model.load_model(args.modelDir, -1)
+    print("Loaded ArXiv-RIFE model")
+model.eval()
+model.device()
+
+if not args.video is None:
+    videoCapture = cv2.VideoCapture(args.video)
+    fps = videoCapture.get(cv2.CAP_PROP_FPS)
+    tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+    videoCapture.release()
+    if args.fps is None:
+        fpsNotAssigned = True
+        args.fps = fps * (2 ** args.exp)
+    else:
+        fpsNotAssigned = False
+    videogen = skvideo.io.vreader(args.video)
+    lastframe = next(videogen)
+    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+    video_path_wo_ext, ext = os.path.splitext(args.video)
+    print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
+    if args.png == False and fpsNotAssigned == True:
+        print("The audio will be merged after interpolation process")
+    else:
+        print("Will not merge audio because using png or fps flag!")
+else:
+    videogen = []
+    for f in os.listdir(args.img):
+        if 'png' in f:
+            videogen.append(f)
+    tot_frame = len(videogen)
+    videogen.sort(key= lambda x:int(x[:-4]))
+    lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+    videogen = videogen[1:]
+h, w, _ = lastframe.shape
+vid_out_name = None
+vid_out = None
+if args.png:
+    if not os.path.exists('vid_out'):
+        os.mkdir('vid_out')
+else:
+    if args.output is not None:
+        vid_out_name = args.output
+    else:
+        vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)
+    vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
+
+def clear_write_buffer(user_args, write_buffer):
+    cnt = 0
+    while True:
+        item = write_buffer.get()
+        if item is None:
+            break
+        if user_args.png:
+            cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+            cnt += 1
+        else:
+            vid_out.write(item[:, :, ::-1])
+
+def build_read_buffer(user_args, read_buffer, videogen):
+    try:
+        for frame in videogen:
+            if not user_args.img is None:
+                frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+            if user_args.montage:
+                frame = frame[:, left: left + w]
+            read_buffer.put(frame)
+    except:
+        pass
+    read_buffer.put(None)
+
+def make_inference(I0, I1, n):
+    global model
+    middle = model.inference(I0, I1, args.scale)
+    if n == 1:
+        return [middle]
+    first_half = make_inference(I0, middle, n=n//2)
+    second_half = make_inference(middle, I1, n=n//2)
+    if n%2:
+        return [*first_half, middle, *second_half]
+    else:
+        return [*first_half, *second_half]
+
+def pad_image(img):
+    if(args.fp16):
+        return F.pad(img, padding).half()
+    else:
+        return F.pad(img, padding)
+
+if args.montage:
+    left = w // 4
+    w = w // 2
+tmp = max(32, int(32 / args.scale))
+ph = ((h - 1) // tmp + 1) * tmp
+pw = ((w - 1) // tmp + 1) * tmp
+padding = (0, pw - w, 0, ph - h)
+pbar = tqdm(total=tot_frame)
+if args.montage:
+    lastframe = lastframe[:, left: left + w]
+write_buffer = Queue(maxsize=500)
+read_buffer = Queue(maxsize=500)
+_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
+_thread.start_new_thread(clear_write_buffer, (args, write_buffer))
+
+I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+I1 = pad_image(I1)
+temp = None # save lastframe when processing static frame
+
+while True:
+    if temp is not None:
+        frame = temp
+        temp = None
+    else:
+        frame = read_buffer.get()
+    if frame is None:
+        break
+    I0 = I1
+    I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+    I1 = pad_image(I1)
+    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+    break_flag = False
+    if ssim > 0.996:
+        frame = read_buffer.get() # read a new frame
+        if frame is None:
+            break_flag = True
+            frame = lastframe
+        else:
+            temp = frame
+        I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+        I1 = pad_image(I1)
+        I1 = model.inference(I0, I1, args.scale)
+        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+        frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+
+    if ssim < 0.2:
+        output = []
+        for i in range((2 ** args.exp) - 1):
+            output.append(I0)
+        '''
+        output = []
+        step = 1 / (2 ** args.exp)
+        alpha = 0
+        for i in range((2 ** args.exp) - 1):
+            alpha += step
+            beta = 1-alpha
+            output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
+        '''
+    else:
+        output = make_inference(I0, I1, 2**args.exp-1) if args.exp else []
+
+    if args.montage:
+        write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+        for mid in output:
+            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+            write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+    else:
+        write_buffer.put(lastframe)
+        for mid in output:
+            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+            write_buffer.put(mid[:h, :w])
+    pbar.update(1)
+    lastframe = frame
+    if break_flag:
+        break
+
+if args.montage:
+    write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+else:
+    write_buffer.put(lastframe)
+
+write_buffer.put(None)
+
+import time
+while(not write_buffer.empty()):
+    time.sleep(0.1)
+pbar.close()
+if not vid_out is None:
+    vid_out.release()
+
+# move audio to new video file if appropriate
+if args.png == False and fpsNotAssigned == True and not args.video is None:
+    try:
+        transferAudio(args.video, vid_out_name)
+    except:
+        print("Audio transfer failed. Interpolated video will have no audio")
+        targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
+        os.rename(targetNoAudio, vid_out_name)
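
A plausible run, with a placeholder input path:

    python3 inference_video.py --video=input.mp4 --exp=1

This should double the frame rate (each additional --exp step doubles it again); when neither --png nor --fps is passed, the source audio is merged back with ffmpeg afterwards. --UHD (equivalent to --scale=0.5) is the suggested setting for 4K input, and --img can point at a directory of numbered PNG frames instead of a video. The script relies on ffmpeg being on PATH plus the skvideo and moviepy packages.
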
train.py ADDED
@@ -0,0 +1,155 @@
+import os
+import cv2
+import math
+import time
+import torch
+import torch.distributed as dist
+import numpy as np
+import random
+import argparse
+
+from model.RIFE import Model
+from dataset import *
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data.distributed import DistributedSampler
+
+device = torch.device("cuda")
+
+log_path = 'train_log'
+
+def get_learning_rate(step):
+    if step < 2000:
+        mul = step / 2000.
+        return 3e-4 * mul
+    else:
+        mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
+        return (3e-4 - 3e-6) * mul + 3e-6
+
+def flow2rgb(flow_map_np):
+    h, w, _ = flow_map_np.shape
+    rgb_map = np.ones((h, w, 3)).astype(np.float32)
+    normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
+
+    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
+    rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
+    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
+    return rgb_map.clip(0, 1)
+
+def train(model, local_rank):
+    if local_rank == 0:
+        writer = SummaryWriter('train')
+        writer_val = SummaryWriter('validate')
+    else:
+        writer = None
+        writer_val = None
+    step = 0
+    nr_eval = 0
+    dataset = VimeoDataset('train')
+    sampler = DistributedSampler(dataset)
+    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
+    args.step_per_epoch = train_data.__len__()
+    dataset_val = VimeoDataset('validation')
+    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
+    print('training...')
+    time_stamp = time.time()
+    for epoch in range(args.epoch):
+        sampler.set_epoch(epoch)
+        for i, data in enumerate(train_data):
+            data_time_interval = time.time() - time_stamp
+            time_stamp = time.time()
+            data_gpu, timestep = data
+            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
+            timestep = timestep.to(device, non_blocking=True)
+            imgs = data_gpu[:, :6]
+            gt = data_gpu[:, 6:9]
+            learning_rate = get_learning_rate(step) * args.world_size / 4
+            pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
+            train_time_interval = time.time() - time_stamp
+            time_stamp = time.time()
+            if step % 200 == 1 and local_rank == 0:
+                writer.add_scalar('learning_rate', learning_rate, step)
+                writer.add_scalar('loss/l1', info['loss_l1'], step)
+                writer.add_scalar('loss/tea', info['loss_tea'], step)
+                writer.add_scalar('loss/distill', info['loss_distill'], step)
+            if step % 1000 == 1 and local_rank == 0:
+                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
+                mask = (torch.cat((info['mask'], info['mask_tea']), 3).permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
+                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
+                merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
+                flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
+                flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()
+                for i in range(5):
+                    imgs = np.concatenate((merged_img[i], pred[i], gt[i]), 1)[:, :, ::-1]
+                    writer.add_image(str(i) + '/img', imgs, step, dataformats='HWC')
+                    writer.add_image(str(i) + '/flow', np.concatenate((flow2rgb(flow0[i]), flow2rgb(flow1[i])), 1), step, dataformats='HWC')
+                    writer.add_image(str(i) + '/mask', mask[i], step, dataformats='HWC')
+                writer.flush()
+            if local_rank == 0:
+                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, info['loss_l1']))
+            step += 1
+        nr_eval += 1
+        if nr_eval % 5 == 0:
+            evaluate(model, val_data, step, local_rank, writer_val)
+        model.save_model(log_path, local_rank)
+        dist.barrier()
+
+def evaluate(model, val_data, nr_eval, local_rank, writer_val):
+    loss_l1_list = []
+    loss_distill_list = []
+    loss_tea_list = []
+    psnr_list = []
+    psnr_list_teacher = []
+    time_stamp = time.time()
+    for i, data in enumerate(val_data):
+        data_gpu, timestep = data
+        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
+        imgs = data_gpu[:, :6]
+        gt = data_gpu[:, 6:9]
+        with torch.no_grad():
+            pred, info = model.update(imgs, gt, training=False)
+            merged_img = info['merged_tea']
+        loss_l1_list.append(info['loss_l1'].cpu().numpy())
+        loss_tea_list.append(info['loss_tea'].cpu().numpy())
+        loss_distill_list.append(info['loss_distill'].cpu().numpy())
+        for j in range(gt.shape[0]):
+            psnr = -10 * math.log10(torch.mean((gt[j] - pred[j]) * (gt[j] - pred[j])).cpu().data)
+            psnr_list.append(psnr)
+            psnr = -10 * math.log10(torch.mean((merged_img[j] - gt[j]) * (merged_img[j] - gt[j])).cpu().data)
+            psnr_list_teacher.append(psnr)
+        gt = (gt.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
+        pred = (pred.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
+        merged_img = (merged_img.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
+        flow0 = info['flow'].permute(0, 2, 3, 1).cpu().numpy()
+        flow1 = info['flow_tea'].permute(0, 2, 3, 1).cpu().numpy()
+        if i == 0 and local_rank == 0:
+            for j in range(10):
+                imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
+                writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
+                writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
+
+    eval_time_interval = time.time() - time_stamp
+
+    if local_rank != 0:
+        return
+    writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
+    writer_val.add_scalar('psnr_teacher', np.array(psnr_list_teacher).mean(), nr_eval)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--epoch', default=300, type=int)
+    parser.add_argument('--batch_size', default=16, type=int, help='minibatch size')
+    parser.add_argument('--local_rank', default=0, type=int, help='local rank')
+    parser.add_argument('--world_size', default=4, type=int, help='world size')
+    args = parser.parse_args()
+    torch.distributed.init_process_group(backend="nccl", world_size=args.world_size)
+    torch.cuda.set_device(args.local_rank)
+    seed = 1234
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.benchmark = True
+    model = Model(args.local_rank)
+    train(model, args.local_rank)
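
train.py is written for one process per GPU (DistributedSampler, NCCL backend, --local_rank), so it needs a distributed launcher. A plausible launch line for 4 GPUs, not included in this commit and assuming the Vimeo90K data sits in ./vimeo_triplet, is:

    python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4

(newer PyTorch releases would use torchrun instead). TensorBoard logs go to ./train and ./validate, and checkpoints are saved to ./train_log.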