tidalove commited on
Commit
38886e4
·
verified ·
1 Parent(s): 03ba85a

updated to check for existence of style transferred image

Browse files
Files changed (1) hide show
  1. test.py +18 -8
test.py CHANGED
@@ -19,6 +19,7 @@ parser.add_argument('--style_dir', type=str, help='Content image folder path')
19
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
20
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
21
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
 
22
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
23
  parser.add_argument('--color_control', action='store_true', help='Preserve content color')
24
  args = parser.parse_args()
@@ -71,13 +72,12 @@ def main():
71
  assert len(style_pths) > 0, 'Failed to load style image'
72
 
73
  # Prepare directory for saving results
74
- out_dir = './results/'
75
- os.makedirs(out_dir, exist_ok=True)
76
 
77
  # Load AdaIN model
78
- vgg = torch.load('vgg_normalized.pth')
79
  model = AdaINNet(vgg).to(device)
80
- model.decoder.load_state_dict(torch.load(args.decoder_weight))
81
  model.eval()
82
 
83
  # Prepare image transform
@@ -95,14 +95,27 @@ def main():
95
 
96
  for content_pth in content_pths:
97
  content_img = Image.open(content_pth)
 
 
98
  content_tensor = t(content_img).unsqueeze(0).to(device)
99
 
100
  if args.grid_pth:
101
  imgs.append(content_img)
102
 
103
  for style_pth in style_pths:
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
106
 
107
  # Linear Histogram Matching if needed
108
  if args.color_control:
@@ -122,9 +135,6 @@ def main():
122
  times.append(toc-tic)
123
 
124
  # Save image
125
- out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
126
- if args.color_control: out_pth += '_colorcontrol'
127
- out_pth += content_pth.suffix
128
  save_image(out_tensor, out_pth)
129
 
130
  if args.grid_pth:
 
19
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
20
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
21
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
22
+ parser.add_argument('--output_dir', type=str, default="results")
23
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
24
  parser.add_argument('--color_control', action='store_true', help='Preserve content color')
25
  args = parser.parse_args()
 
72
  assert len(style_pths) > 0, 'Failed to load style image'
73
 
74
  # Prepare directory for saving results
75
+ os.makedirs(args.output_dir, exist_ok=True)
 
76
 
77
  # Load AdaIN model
78
+ vgg = torch.load('vgg_normalized.pth', weights_only=False)
79
  model = AdaINNet(vgg).to(device)
80
+ model.decoder.load_state_dict(torch.load(args.decoder_weight, weights_only=False))
81
  model.eval()
82
 
83
  # Prepare image transform
 
95
 
96
  for content_pth in content_pths:
97
  content_img = Image.open(content_pth)
98
+ if not content_img.mode == "RGB":
99
+ content_img = content_img.convert("RGB")
100
  content_tensor = t(content_img).unsqueeze(0).to(device)
101
 
102
  if args.grid_pth:
103
  imgs.append(content_img)
104
 
105
  for style_pth in style_pths:
106
+
107
+ # check if style transferred image exists already
108
+ out_pth = os.path.join(args.output_dir, content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix)
109
+ if os.path.isfile(out_pth):
110
+ print("Skipping existing file")
111
+ continue
112
+
113
+ style_img = Image.open(style_pth)
114
+
115
+ if not style_img.mode == "RGB":
116
+ style_img = style_img.convert("RGB")
117
 
118
+ style_tensor = t(style_img).unsqueeze(0).to(device)
119
 
120
  # Linear Histogram Matching if needed
121
  if args.color_control:
 
135
  times.append(toc-tic)
136
 
137
  # Save image
 
 
 
138
  save_image(out_tensor, out_pth)
139
 
140
  if args.grid_pth: