updated to check for existence of style transferred image
Browse files
test.py
CHANGED
@@ -19,6 +19,7 @@ parser.add_argument('--style_dir', type=str, help='Content image folder path')
|
|
19 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
22 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
23 |
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
24 |
args = parser.parse_args()
|
@@ -71,13 +72,12 @@ def main():
|
|
71 |
assert len(style_pths) > 0, 'Failed to load style image'
|
72 |
|
73 |
# Prepare directory for saving results
|
74 |
-
|
75 |
-
os.makedirs(out_dir, exist_ok=True)
|
76 |
|
77 |
# Load AdaIN model
|
78 |
-
vgg = torch.load('vgg_normalized.pth')
|
79 |
model = AdaINNet(vgg).to(device)
|
80 |
-
model.decoder.load_state_dict(torch.load(args.decoder_weight))
|
81 |
model.eval()
|
82 |
|
83 |
# Prepare image transform
|
@@ -95,14 +95,27 @@ def main():
|
|
95 |
|
96 |
for content_pth in content_pths:
|
97 |
content_img = Image.open(content_pth)
|
|
|
|
|
98 |
content_tensor = t(content_img).unsqueeze(0).to(device)
|
99 |
|
100 |
if args.grid_pth:
|
101 |
imgs.append(content_img)
|
102 |
|
103 |
for style_pth in style_pths:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
style_tensor = t(
|
106 |
|
107 |
# Linear Histogram Matching if needed
|
108 |
if args.color_control:
|
@@ -122,9 +135,6 @@ def main():
|
|
122 |
times.append(toc-tic)
|
123 |
|
124 |
# Save image
|
125 |
-
out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
|
126 |
-
if args.color_control: out_pth += '_colorcontrol'
|
127 |
-
out_pth += content_pth.suffix
|
128 |
save_image(out_tensor, out_pth)
|
129 |
|
130 |
if args.grid_pth:
|
|
|
19 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
22 |
+
parser.add_argument('--output_dir', type=str, default="results")
|
23 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
24 |
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
25 |
args = parser.parse_args()
|
|
|
72 |
assert len(style_pths) > 0, 'Failed to load style image'
|
73 |
|
74 |
# Prepare directory for saving results
|
75 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
76 |
|
77 |
# Load AdaIN model
|
78 |
+
vgg = torch.load('vgg_normalized.pth', weights_only=False)
|
79 |
model = AdaINNet(vgg).to(device)
|
80 |
+
model.decoder.load_state_dict(torch.load(args.decoder_weight, weights_only=False))
|
81 |
model.eval()
|
82 |
|
83 |
# Prepare image transform
|
|
|
95 |
|
96 |
for content_pth in content_pths:
|
97 |
content_img = Image.open(content_pth)
|
98 |
+
if not content_img.mode == "RGB":
|
99 |
+
content_img = content_img.convert("RGB")
|
100 |
content_tensor = t(content_img).unsqueeze(0).to(device)
|
101 |
|
102 |
if args.grid_pth:
|
103 |
imgs.append(content_img)
|
104 |
|
105 |
for style_pth in style_pths:
|
106 |
+
|
107 |
+
# check if style transferred image exists already
|
108 |
+
out_pth = os.path.join(args.output_dir, content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix)
|
109 |
+
if os.path.isfile(out_pth):
|
110 |
+
print("Skipping existing file")
|
111 |
+
continue
|
112 |
+
|
113 |
+
style_img = Image.open(style_pth)
|
114 |
+
|
115 |
+
if not style_img.mode == "RGB":
|
116 |
+
style_img = style_img.convert("RGB")
|
117 |
|
118 |
+
style_tensor = t(style_img).unsqueeze(0).to(device)
|
119 |
|
120 |
# Linear Histogram Matching if needed
|
121 |
if args.color_control:
|
|
|
135 |
times.append(toc-tic)
|
136 |
|
137 |
# Save image
|
|
|
|
|
|
|
138 |
save_image(out_tensor, out_pth)
|
139 |
|
140 |
if args.grid_pth:
|