falseu commited on
Commit
4f6c34a
·
1 Parent(s): 0dfef01

update comments

Browse files
Files changed (5) hide show
  1. AdaIN.py +20 -1
  2. test.py +11 -3
  3. test_interpolate.py +6 -2
  4. test_video.py +6 -1
  5. train.py +11 -4
AdaIN.py CHANGED
@@ -13,7 +13,11 @@ class AdaINNet(nn.Module):
13
  def __init__(self, vgg_weight):
14
  super().__init__()
15
  self.encoder = vgg19(vgg_weight)
16
- self.encoder = nn.Sequential(*list(self.encoder.children())[:22]) # drop layers after 4_1
 
 
 
 
17
  for parameter in self.encoder.parameters():
18
  parameter.requires_grad = False
19
 
@@ -21,15 +25,29 @@ class AdaINNet(nn.Module):
21
 
22
  self.mseloss = nn.MSELoss()
23
 
 
 
 
 
 
 
 
 
 
 
24
  def _style_loss(self, x, y):
25
  return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
26
  self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
27
 
28
  def forward(self, content, style, alpha=1.0):
 
29
  content_enc = self.encoder(content)
30
  style_enc = self.encoder(style)
 
 
31
  transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
32
 
 
33
  out = self.decoder(transfer_enc)
34
 
35
  # vgg19 layer relu1_1
@@ -47,6 +65,7 @@ class AdaINNet(nn.Module):
47
  # vgg19 layer relu4_1
48
  out_enc = self.encoder[13:](out_relu31)
49
 
 
50
  content_loss = self.mseloss(out_enc, transfer_enc)
51
  style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
52
  self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
 
13
  def __init__(self, vgg_weight):
14
  super().__init__()
15
  self.encoder = vgg19(vgg_weight)
16
+
17
+ # drop layers after 4_1
18
+ self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
19
+
20
+ # No optimization for encoder
21
  for parameter in self.encoder.parameters():
22
  parameter.requires_grad = False
23
 
 
25
 
26
  self.mseloss = nn.MSELoss()
27
 
28
+ """
29
+ Computes style loss of two images
30
+
31
+ Args:
32
+ x (torch.FloatTensor): content image tensor
33
+ y (torch.FloatTensor): style image tensor
34
+
35
+ Return:
36
+ Mean Squared Error between x.mean, y.mean and MSE between x.std, y.std
37
+ """
38
  def _style_loss(self, x, y):
39
  return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
40
  self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
41
 
42
  def forward(self, content, style, alpha=1.0):
43
+ # Generate image features
44
  content_enc = self.encoder(content)
45
  style_enc = self.encoder(style)
46
+
47
+ # Perform style transfer on feature space
48
  transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
49
 
50
+ # Generate outptu image
51
  out = self.decoder(transfer_enc)
52
 
53
  # vgg19 layer relu1_1
 
65
  # vgg19 layer relu4_1
66
  out_enc = self.encoder[13:](out_relu31)
67
 
68
+ # Calculate loss
69
  content_loss = self.mseloss(out_enc, transfer_enc)
70
  style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
71
  self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
test.py CHANGED
@@ -69,6 +69,7 @@ def main():
69
  assert len(content_pths) > 0, 'Failed to load content image'
70
  assert len(style_pths) > 0, 'Failed to load style image'
71
 
 
72
  out_dir = './results/'
73
  os.makedirs(out_dir, exist_ok=True)
74
 
@@ -81,8 +82,9 @@ def main():
81
  # Prepare image transform
82
  t = transform(512)
83
 
84
- # Prepare grid image
85
  if args.grid_pth:
 
86
  imgs = [np.ones((1, 1, 3), np.uint8) * 255]
87
  for style_pth in style_pths:
88
  imgs.append(Image.open(style_pth))
@@ -101,15 +103,20 @@ def main():
101
 
102
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
103
 
104
- tic = time.perf_counter() # Start time
 
 
 
105
  with torch.no_grad():
106
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
107
 
108
- toc = time.perf_counter() # End time
 
109
  print("Content: " + content_pth.stem + ". Style: " \
110
  + style_pth.stem + '. Alpha: ' + str(args.alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
111
  times.append(toc-tic)
112
 
 
113
  out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
114
  save_image(out_tensor, out_pth)
115
 
@@ -122,6 +129,7 @@ def main():
122
  avg = sum(times)/len(times)
123
  print("Average style transfer time: %.4f seconds" % (avg))
124
 
 
125
  if args.grid_pth:
126
  print("Generating grid image")
127
  grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
 
69
  assert len(content_pths) > 0, 'Failed to load content image'
70
  assert len(style_pths) > 0, 'Failed to load style image'
71
 
72
+ # Prepare directory for saving results
73
  out_dir = './results/'
74
  os.makedirs(out_dir, exist_ok=True)
75
 
 
82
  # Prepare image transform
83
  t = transform(512)
84
 
85
+ # Prepare grid image, add style images to the first row
86
  if args.grid_pth:
87
+ # Add empty image
88
  imgs = [np.ones((1, 1, 3), np.uint8) * 255]
89
  for style_pth in style_pths:
90
  imgs.append(Image.open(style_pth))
 
103
 
104
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
105
 
106
+ # Start time
107
+ tic = time.perf_counter()
108
+
109
+ # Execute style transfer
110
  with torch.no_grad():
111
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
112
 
113
+ # End time
114
+ toc = time.perf_counter()
115
  print("Content: " + content_pth.stem + ". Style: " \
116
  + style_pth.stem + '. Alpha: ' + str(args.alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
117
  times.append(toc-tic)
118
 
119
+ # Save image
120
  out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
121
  save_image(out_tensor, out_pth)
122
 
 
129
  avg = sum(times)/len(times)
130
  print("Average style transfer time: %.4f seconds" % (avg))
131
 
132
+ # Generate grid image
133
  if args.grid_pth:
134
  print("Generating grid image")
135
  grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
test_interpolate.py CHANGED
@@ -102,24 +102,28 @@ def main():
102
  for content_pth in content_pths:
103
  content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)
104
 
 
105
  style_tensor = []
106
  for style_pth in style_pths:
107
  img = Image.open(style_pth)
108
- style_tensor.append(transform([512, 512])(img)) # Convert style images to same size
109
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
110
 
111
- for inter_weight in inter_weights:
 
112
  with torch.no_grad():
113
  out_tensor = out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()
114
 
115
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
116
 
 
117
  out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
118
  save_image(out_tensor, out_pth)
119
 
120
  if args.grid_pth:
121
  imgs.append(Image.open(out_pth))
122
 
 
123
  if args.grid_pth:
124
  print("Generating grid image")
125
  grid_image(5, 5, imgs, save_pth=args.grid_pth)
 
102
  for content_pth in content_pths:
103
  content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)
104
 
105
+ # Prepare multiple style images
106
  style_tensor = []
107
  for style_pth in style_pths:
108
  img = Image.open(style_pth)
109
+ style_tensor.append(transform([512, 512])(img))
110
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
111
 
112
+ for inter_weight in inter_weights:
113
+ # Execute Interpolate style transfer
114
  with torch.no_grad():
115
  out_tensor = out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()
116
 
117
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
118
 
119
+ # Save results
120
  out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
121
  save_image(out_tensor, out_pth)
122
 
123
  if args.grid_pth:
124
  imgs.append(Image.open(out_pth))
125
 
126
+ # Generate grid image
127
  if args.grid_pth:
128
  print("Generating grid image")
129
  grid_image(5, 5, imgs, save_pth=args.grid_pth)
test_video.py CHANGED
@@ -55,13 +55,16 @@ def main():
55
  style_image_pth = Path(args.style_image)
56
  style_image = Image.open(style_image_pth)
57
 
 
58
  fps = int(content_video.get(cv2.CAP_PROP_FPS))
59
  frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
60
  video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
61
  video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))
62
 
 
63
  video_tqdm = tqdm(frame_count)
64
 
 
65
  out_dir = './results_video/'
66
  os.makedirs(out_dir, exist_ok=True)
67
  out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
@@ -81,7 +84,8 @@ def main():
81
 
82
  while content_video.isOpened():
83
  ret, content_image = content_video.read()
84
- if not ret: # Failed to read a frame
 
85
  break
86
 
87
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
@@ -96,6 +100,7 @@ def main():
96
  out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
97
  out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)
98
 
 
99
  writer.append_data(np.array(out_tensor))
100
  video_tqdm.update(1)
101
 
 
55
  style_image_pth = Path(args.style_image)
56
  style_image = Image.open(style_image_pth)
57
 
58
+ # Read video info
59
  fps = int(content_video.get(cv2.CAP_PROP_FPS))
60
  frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
61
  video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
62
  video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))
63
 
64
+ # Prepare loop
65
  video_tqdm = tqdm(frame_count)
66
 
67
+ # Prepare output video writer
68
  out_dir = './results_video/'
69
  os.makedirs(out_dir, exist_ok=True)
70
  out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
 
84
 
85
  while content_video.isOpened():
86
  ret, content_image = content_video.read()
87
+ # Failed to read a frame
88
+ if not ret:
89
  break
90
 
91
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
 
100
  out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
101
  out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)
102
 
103
+ # Write output frame to video
104
  writer.append_data(np.array(out_tensor))
105
  video_tqdm.update(1)
106
 
train.py CHANGED
@@ -17,21 +17,24 @@ def main():
17
  args = parser.parse_args()
18
 
19
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
20
-
21
  check_point_dir = './check_point/'
22
  weights_dir = './weights/'
 
 
23
  train_set = TrainSet(args.content_dir, args.style_dir)
24
  train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)
25
-
 
26
  vgg_model = torch.load('vgg_normalized.pth')
27
  model = AdaINNet(vgg_model).to(device)
28
-
29
  decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
 
30
  total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
31
  losses = []
32
  iteration = 0
33
 
34
- # If resume
35
  if args.resume > 0:
36
  states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
37
  model.decoder.load_state_dict(states['decoder'])
@@ -54,10 +57,14 @@ def main():
54
  content_batch = content_batch.to(device)
55
  style_batch = style_batch.to(device)
56
 
 
57
  loss_content, loss_style = model(content_batch, style_batch)
58
  loss_scaled = loss_content + 10 * loss_style
 
 
59
  loss_scaled.backward()
60
  decoder_optimizer.step()
 
61
  total_loss = loss_scaled.item()
62
  content_loss = loss_content.item()
63
  style_loss = loss_style.item()
 
17
  args = parser.parse_args()
18
 
19
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
20
+
21
  check_point_dir = './check_point/'
22
  weights_dir = './weights/'
23
+
24
+ # Prepare Training dataset
25
  train_set = TrainSet(args.content_dir, args.style_dir)
26
  train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)
27
+
28
+ # load vgg19 weights
29
  vgg_model = torch.load('vgg_normalized.pth')
30
  model = AdaINNet(vgg_model).to(device)
 
31
  decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
32
+
33
  total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
34
  losses = []
35
  iteration = 0
36
 
37
+ # If resume training, load states
38
  if args.resume > 0:
39
  states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
40
  model.decoder.load_state_dict(states['decoder'])
 
57
  content_batch = content_batch.to(device)
58
  style_batch = style_batch.to(device)
59
 
60
+ # Feed forward and compute loss
61
  loss_content, loss_style = model(content_batch, style_batch)
62
  loss_scaled = loss_content + 10 * loss_style
63
+
64
+ # Gradient descent
65
  loss_scaled.backward()
66
  decoder_optimizer.step()
67
+
68
  total_loss = loss_scaled.item()
69
  content_loss = loss_content.item()
70
  style_loss = loss_style.item()