AItool commited on
Commit
42cbb25
·
verified ·
1 Parent(s): a43d977

Update inference_img.py

Browse files
Files changed (1) hide show
  1. inference_img.py +106 -112
inference_img.py CHANGED
@@ -1,118 +1,112 @@
1
- import os
2
- import cv2
3
- import torch
4
- import argparse
5
- from torch.nn import functional as F
6
- import warnings
7
 
8
- warnings.filterwarnings("ignore")
9
-
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- torch.set_grad_enabled(False)
12
- if torch.cuda.is_available():
13
- torch.backends.cudnn.enabled = True
14
- torch.backends.cudnn.benchmark = True
15
-
16
- parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
17
- parser.add_argument('--img', dest='img', nargs=2, required=True)
18
- parser.add_argument('--exp', default=2, type=int)
19
- parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
20
- parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
21
- parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
22
- parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
23
-
24
- args = parser.parse_args()
25
-
26
- try:
27
- from train_log.RIFE_HDv3 import Model
28
- model = Model()
29
- model.load_model(args.modelDir, -1)
30
- print("Loaded RIFE_HDv3 model.")
31
- print("Checkpoint reached!")
32
- except:
33
- from train_log.IFNet_HDv3 import Model
34
- model = Model()
35
- model.load_model(args.modelDir, -1)
36
- print("Loaded IFNet_HDv3 model.")
37
- print("Checkpoint reached!")
38
-
39
- model.eval()
40
- model.device()
41
 
42
- if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
43
- img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
44
- img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
45
- img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
46
- img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
47
-
48
- else:
49
- img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
50
- img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
51
- img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
52
- img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
53
-
54
- n, c, h, w = img0.shape
55
- ph = ((h - 1) // 32 + 1) * 32
56
- pw = ((w - 1) // 32 + 1) * 32
57
- padding = (0, pw - w, 0, ph - h)
58
- img0 = F.pad(img0, padding)
59
- img1 = F.pad(img1, padding)
60
-
61
- if args.ratio:
62
- img_list = [img0]
63
- img0_ratio = 0.0
64
- img1_ratio = 1.0
65
- if args.ratio <= img0_ratio + args.rthreshold / 2:
66
- middle = img0
67
- elif args.ratio >= img1_ratio - args.rthreshold / 2:
68
- middle = img1
69
- else:
70
- tmp_img0 = img0
71
- tmp_img1 = img1
72
- for inference_cycle in range(args.rmaxcycles):
73
- middle = model.inference(tmp_img0, tmp_img1)
74
- middle_ratio = ( img0_ratio + img1_ratio ) / 2
75
- if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
76
- break
77
- if args.ratio > middle_ratio:
78
- tmp_img0 = middle
79
- img0_ratio = middle_ratio
80
- else:
81
- tmp_img1 = middle
82
- img1_ratio = middle_ratio
83
- img_list.append(middle)
84
- img_list.append(img1)
85
- else:
86
- img_list = [img0, img1]
87
- for i in range(args.exp):
88
- tmp = []
89
- for j in range(len(img_list) - 1):
90
- mid = model.inference(img_list[j], img_list[j + 1])
91
- tmp.append(img_list[j])
92
- tmp.append(mid)
93
- tmp.append(img1)
94
- img_list = tmp
95
-
96
- if not os.path.exists('output'):
97
- os.mkdir('output')
98
-
99
- print("Checkpoint reached!")
100
 
101
- for i in range(len(img_list)):
102
- if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
103
- cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
 
 
 
 
104
 
105
- # Replace this line (or add below your current cv2.imwrite)
106
- save_path = 'output/img{}.png'.format(i)
107
- success = cv2.imwrite(save_path, (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
108
- print(f"Saving to {save_path} → success: {success}")
109
 
110
- else:
111
- cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
112
-
113
- # Replace this line (or add below your current cv2.imwrite)
114
- save_path = 'output/img{}.png'.format(i)
115
- success = cv2.imwrite(save_path, (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
116
- print(f"Saving to {save_path} → success: {success}")
117
-
118
  print("Checkpoint reached!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ from torch.nn import functional as F
6
+ import warnings
7
 
8
+ warnings.filterwarnings("ignore")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ torch.set_grad_enabled(False)
12
+ if torch.cuda.is_available():
13
+ torch.backends.cudnn.enabled = True
14
+ torch.backends.cudnn.benchmark = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
17
+ parser.add_argument('--img', dest='img', nargs=2, required=True)
18
+ parser.add_argument('--exp', default=2, type=int)
19
+ parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
20
+ parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
21
+ parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
22
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
23
 
24
+ args = parser.parse_args()
 
 
 
25
 
26
+ try:
27
+ from train_log.RIFE_HDv3 import Model
28
+ model = Model()
29
+ model.load_model(args.modelDir, -1)
30
+ print("Loaded RIFE_HDv3 model.")
 
 
 
31
  print("Checkpoint reached!")
32
+ except:
33
+ from train_log.IFNet_HDv3 import Model
34
+ model = Model()
35
+ model.load_model(args.modelDir, -1)
36
+ print("Loaded IFNet_HDv3 model.")
37
+ print("Checkpoint reached!")
38
+
39
+ model.eval()
40
+ model.device()
41
+
42
+ if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
43
+ img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
44
+ img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
45
+ img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0)
46
+ img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0)
47
+ else:
48
+ img0 = cv2.imread(args.img[0], cv2.IMREAD_UNCHANGED)
49
+ img1 = cv2.imread(args.img[1], cv2.IMREAD_UNCHANGED)
50
+ img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
51
+ img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
52
+
53
+ n, c, h, w = img0.shape
54
+ ph = ((h - 1) // 32 + 1) * 32
55
+ pw = ((w - 1) // 32 + 1) * 32
56
+ padding = (0, pw - w, 0, ph - h)
57
+ img0 = F.pad(img0, padding)
58
+ img1 = F.pad(img1, padding)
59
+
60
+ if args.ratio:
61
+ img_list = [img0]
62
+ img0_ratio = 0.0
63
+ img1_ratio = 1.0
64
+ if args.ratio <= img0_ratio + args.rthreshold / 2:
65
+ middle = img0
66
+ elif args.ratio >= img1_ratio - args.rthreshold / 2:
67
+ middle = img1
68
+ else:
69
+ tmp_img0 = img0
70
+ tmp_img1 = img1
71
+ for inference_cycle in range(args.rmaxcycles):
72
+ middle = model.inference(tmp_img0, tmp_img1)
73
+ middle_ratio = (img0_ratio + img1_ratio) / 2
74
+ if args.ratio - (args.rthreshold / 2) <= middle_ratio <= args.ratio + (args.rthreshold / 2):
75
+ break
76
+ if args.ratio > middle_ratio:
77
+ tmp_img0 = middle
78
+ img0_ratio = middle_ratio
79
+ else:
80
+ tmp_img1 = middle
81
+ img1_ratio = middle_ratio
82
+ img_list.append(middle)
83
+ img_list.append(img1)
84
+ else:
85
+ img_list = [img0, img1]
86
+ for i in range(args.exp):
87
+ tmp = []
88
+ for j in range(len(img_list) - 1):
89
+ mid = model.inference(img_list[j], img_list[j + 1])
90
+ tmp.append(img_list[j])
91
+ tmp.append(mid)
92
+ tmp.append(img1)
93
+ img_list = tmp
94
+
95
+ if not os.path.exists('output'):
96
+ os.mkdir('output')
97
+
98
+ print("Checkpoint reached!")
99
+
100
+ for i in range(len(img_list)):
101
+ if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
102
+ cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
103
+ save_path = 'output/img{}.png'.format(i)
104
+ success = cv2.imwrite(save_path, (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
105
+ print(f"Saving to {save_path} → success: {success}")
106
+ else:
107
+ cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
108
+ save_path = 'output/img{}.png'.format(i)
109
+ success = cv2.imwrite(save_path, (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])
110
+ print(f"Saving to {save_path} → success: {success}")
111
+
112
+ print("Checkpoint reached!")