52Hz commited on
Commit
9ae15c5
·
1 Parent(s): 1eb1887

Update main_test_SRMNet.py

Browse files
Files changed (1) hide show
  1. main_test_SRMNet.py +17 -19
main_test_SRMNet.py CHANGED
@@ -14,24 +14,6 @@ from natsort import natsorted
14
  from model.SRMNet import SRMNet
15
  from utils import util_calculate_psnr_ssim as util
16
 
17
-
18
- def save_img(filepath, img):
19
- cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
20
-
21
-
22
- def load_checkpoint(model, weights):
23
- checkpoint = torch.load(weights)
24
- try:
25
- model.load_state_dict(checkpoint["state_dict"])
26
- except:
27
- state_dict = checkpoint["state_dict"]
28
- new_state_dict = OrderedDict()
29
- for k, v in state_dict.items():
30
- name = k[7:] # remove `module.`
31
- new_state_dict[name] = v
32
- model.load_state_dict(new_state_dict)
33
-
34
-
35
  def main():
36
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
37
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
@@ -47,7 +29,7 @@ def main():
47
 
48
  os.makedirs(out_dir, exist_ok=True)
49
 
50
- files = natsorted(glob(os.path.join(inp_dir, '*')))
51
 
52
  if len(files) == 0:
53
  raise Exception(f"No files found at {inp_dir}")
@@ -82,6 +64,22 @@ def main():
82
  save_img((os.path.join(out_dir, f + '.png')), restored)
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def setup(args):
86
  save_dir = 'result/'
87
  folder = 'test/'
 
14
  from model.SRMNet import SRMNet
15
  from utils import util_calculate_psnr_ssim as util
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def main():
18
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
19
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
 
29
 
30
  os.makedirs(out_dir, exist_ok=True)
31
 
32
+ files = natsorted(glob.glob(os.path.join(inp_dir, '*')))
33
 
34
  if len(files) == 0:
35
  raise Exception(f"No files found at {inp_dir}")
 
64
  save_img((os.path.join(out_dir, f + '.png')), restored)
65
 
66
 
67
+ def save_img(filepath, img):
68
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
69
+
70
+
71
+ def load_checkpoint(model, weights):
72
+ checkpoint = torch.load(weights)
73
+ try:
74
+ model.load_state_dict(checkpoint["state_dict"])
75
+ except:
76
+ state_dict = checkpoint["state_dict"]
77
+ new_state_dict = OrderedDict()
78
+ for k, v in state_dict.items():
79
+ name = k[7:] # remove `module.`
80
+ new_state_dict[name] = v
81
+ model.load_state_dict(new_state_dict)
82
+
83
  def setup(args):
84
  save_dir = 'result/'
85
  folder = 'test/'