falseu committed
Commit · 4f6c34a · 1 Parent(s): 0dfef01

update comments

Files changed:
- AdaIN.py +20 -1
- test.py +11 -3
- test_interpolate.py +6 -2
- test_video.py +6 -1
- train.py +11 -4
AdaIN.py CHANGED

@@ -13,7 +13,11 @@ class AdaINNet(nn.Module):
     def __init__(self, vgg_weight):
         super().__init__()
         self.encoder = vgg19(vgg_weight)
-
+
+        # Drop layers after relu4_1
+        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+        # No optimization for the encoder
         for parameter in self.encoder.parameters():
             parameter.requires_grad = False
 
@@ -21,15 +25,29 @@ class AdaINNet(nn.Module):
 
         self.mseloss = nn.MSELoss()
 
     def _style_loss(self, x, y):
+        """
+        Computes the style loss between two feature maps.
+
+        Args:
+            x (torch.FloatTensor): content image features
+            y (torch.FloatTensor): style image features
+
+        Returns:
+            MSE between the channel means of x and y, plus MSE between their channel stds
+        """
         return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
             self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
 
     def forward(self, content, style, alpha=1.0):
+        # Generate image features
         content_enc = self.encoder(content)
         style_enc = self.encoder(style)
+
+        # Perform style transfer in feature space
         transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
 
+        # Generate output image
         out = self.decoder(transfer_enc)
 
         # vgg19 layer relu1_1
@@ -47,6 +65,7 @@ class AdaINNet(nn.Module):
         # vgg19 layer relu4_1
         out_enc = self.encoder[13:](out_relu31)
 
+        # Calculate losses
         content_loss = self.mseloss(out_enc, transfer_enc)
         style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
             self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
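Note: `adaptive_instance_normalization` is only called in these hunks, never defined. For readers, a minimal sketch of the usual AdaIN operation it presumably implements — aligning the channel-wise mean and standard deviation of the content features with those of the style features. The body below is an assumption for illustration, not part of the commit:

import torch

def adaptive_instance_normalization(content_enc, style_enc, eps=1e-5):
    # Channel-wise statistics over the spatial dims, shape (N, C, 1, 1)
    c_mean = content_enc.mean(dim=[2, 3], keepdim=True)
    c_std = content_enc.std(dim=[2, 3], keepdim=True) + eps  # eps avoids division by zero
    s_mean = style_enc.mean(dim=[2, 3], keepdim=True)
    s_std = style_enc.std(dim=[2, 3], keepdim=True)
    # Normalize the content features, then rescale with the style statistics
    return s_std * (content_enc - c_mean) / c_std + s_mean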
test.py CHANGED

@@ -69,6 +69,7 @@ def main():
     assert len(content_pths) > 0, 'Failed to load content image'
     assert len(style_pths) > 0, 'Failed to load style image'
 
+    # Prepare directory for saving results
     out_dir = './results/'
     os.makedirs(out_dir, exist_ok=True)
 
@@ -81,8 +82,9 @@ def main():
     # Prepare image transform
     t = transform(512)
 
-    # Prepare grid image
+    # Prepare grid image, adding style images to the first row
     if args.grid_pth:
+        # Add an empty corner image
        imgs = [np.ones((1, 1, 3), np.uint8) * 255]
        for style_pth in style_pths:
            imgs.append(Image.open(style_pth))
@@ -101,15 +103,20 @@ def main():
 
         style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
 
-        tic = time.perf_counter()
+        # Start timer
+        tic = time.perf_counter()
+
+        # Execute style transfer
         with torch.no_grad():
             out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
 
-        toc = time.perf_counter()
+        # Stop timer
+        toc = time.perf_counter()
         print("Content: " + content_pth.stem + ". Style: " \
             + style_pth.stem + '. Alpha: ' + str(args.alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
         times.append(toc-tic)
 
+        # Save image
         out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
         save_image(out_tensor, out_pth)
 
@@ -122,6 +129,7 @@ def main():
     avg = sum(times)/len(times)
     print("Average style transfer time: %.4f seconds" % (avg))
 
+    # Generate grid image
     if args.grid_pth:
         print("Generating grid image")
         grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
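`style_transfer` is called by test.py but its body is not shown in this diff. In the standard AdaIN inference path, `alpha` blends the transferred features with the original content features before decoding; a sketch under that assumption, reusing the `adaptive_instance_normalization` sketch above (no `no_grad` inside, since the call site already wraps it):

import torch

def style_transfer(content, style, encoder, decoder, alpha=1.0):
    # Encode both images up to relu4_1
    content_enc = encoder(content)
    style_enc = encoder(style)
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    # alpha = 1.0 gives full stylization, alpha = 0.0 reproduces the content
    blended = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(blended)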
test_interpolate.py CHANGED

@@ -102,24 +102,28 @@ def main():
     for content_pth in content_pths:
         content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)
 
+        # Prepare multiple style images
         style_tensor = []
         for style_pth in style_pths:
             img = Image.open(style_pth)
-            style_tensor.append(transform([512, 512])(img))
+            style_tensor.append(transform([512, 512])(img))
         style_tensor = torch.stack(style_tensor, dim=0).to(device)
 
-        for inter_weight in inter_weights:
+        for inter_weight in inter_weights:
+            # Execute interpolated style transfer
             with torch.no_grad():
                 out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()
 
             print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: " + str(inter_weight))
 
+            # Save results
             out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
             save_image(out_tensor, out_pth)
 
             if args.grid_pth:
                 imgs.append(Image.open(out_pth))
 
+    # Generate grid image
     if args.grid_pth:
         print("Generating grid image")
         grid_image(5, 5, imgs, save_pth=args.grid_pth)
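`interpolate_style_transfer` is likewise only visible at its call site. In the AdaIN paper, style interpolation decodes a convex combination of the AdaIN features produced by each style image; a sketch under that assumption — the signature, and treating `inter_weight` as one weight per style, are guesses from the call above:

import torch

def interpolate_style_transfer(content, styles, encoder, decoder, alpha=1.0, weights=None):
    content_enc = encoder(content)
    blended = torch.zeros_like(content_enc)
    # One AdaIN pass per style image, combined by its interpolation weight
    for style, w in zip(styles, weights):
        style_enc = encoder(style.unsqueeze(0))
        blended = blended + w * adaptive_instance_normalization(content_enc, style_enc)
    # Usual content-style trade-off before decoding
    blended = alpha * blended + (1 - alpha) * content_enc
    return decoder(blended)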
test_video.py CHANGED

@@ -55,13 +55,16 @@ def main():
     style_image_pth = Path(args.style_image)
     style_image = Image.open(style_image_pth)
 
+    # Read video info
     fps = int(content_video.get(cv2.CAP_PROP_FPS))
     frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
     video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
     video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))
 
+    # Prepare progress bar
     video_tqdm = tqdm(total=frame_count)
 
+    # Prepare output video writer
     out_dir = './results_video/'
     os.makedirs(out_dir, exist_ok=True)
     out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
@@ -81,7 +84,8 @@ def main():
 
     while content_video.isOpened():
         ret, content_image = content_video.read()
-        if not ret:
+        # Stop when no frame could be read
+        if not ret:
             break
 
         content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
@@ -96,6 +100,7 @@ def main():
         out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
         out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)
 
+        # Write output frame to video
         writer.append_data(np.array(out_tensor))
         video_tqdm.update(1)
 
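One caveat in the loop above: `content_video.read()` returns frames in OpenCV's BGR channel order, while `Image.fromarray` assumes RGB, so unless a conversion happens elsewhere in the script the stylized output will likely have red and blue swapped. A fix would look like this, using the script's existing `t` and `device`:

frame_rgb = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
content_tensor = t(Image.fromarray(frame_rgb)).unsqueeze(0).to(device)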
train.py CHANGED

@@ -17,21 +17,24 @@ def main():
     args = parser.parse_args()
 
     device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
-
+
     check_point_dir = './check_point/'
     weights_dir = './weights/'
+
+    # Prepare training dataset
     train_set = TrainSet(args.content_dir, args.style_dir)
     train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)
-
+
+    # Load vgg19 weights
     vgg_model = torch.load('vgg_normalized.pth')
     model = AdaINNet(vgg_model).to(device)
-
     decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
+
     total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
     losses = []
     iteration = 0
 
-    # If resume
+    # If resuming training, load saved states
     if args.resume > 0:
         states = torch.load(check_point_dir + "epoch_" + str(args.resume) + '.pth')
         model.decoder.load_state_dict(states['decoder'])
@@ -54,10 +57,14 @@ def main():
         content_batch = content_batch.to(device)
         style_batch = style_batch.to(device)
 
+        # Feed forward and compute loss
         loss_content, loss_style = model(content_batch, style_batch)
         loss_scaled = loss_content + 10 * loss_style
+
+        # Gradient descent
         loss_scaled.backward()
         decoder_optimizer.step()
+
         total_loss = loss_scaled.item()
         content_loss = loss_content.item()
         style_loss = loss_style.item()
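The resume branch implies checkpoints are dicts saved as `check_point/epoch_N.pth` with at least a 'decoder' key holding the decoder's state_dict. A sketch of the matching save side; every key except 'decoder', and the `epoch` variable, are assumptions, since only the load of 'decoder' appears in these hunks:

states = {
    'decoder': model.decoder.state_dict(),
    'optimizer': decoder_optimizer.state_dict(),  # assumed, not shown in the diff
    'iteration': iteration,                       # assumed, not shown in the diff
}
torch.save(states, check_point_dir + 'epoch_' + str(epoch) + '.pth')  # epoch is hypothetical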