Chengkai Yang committed on
Commit
4e1a0f5
·
2 Parent(s): 07167e3 20803d9

Merge pull request #3 from MasaTate/main

Browse files
Files changed (6) hide show
  1. .gitignore +5 -0
  2. README.md +3 -1
  3. test.py +9 -2
  4. test_interpolate.py +12 -3
  5. test_video.py +11 -3
  6. utils.py +27 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #Ignore __pycache__
2
+ /__pycache__/
3
+
4
+ #Ignore results
5
+ /results*/
README.md CHANGED
@@ -73,7 +73,9 @@ optional arguments:
73
  To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
74
 
75
  ```
76
- $ python test_interpolation.py --content_image $IMG --style_image $STYLE --interpolation_weights $WEIGHTS --cuda
 
 
77
 
78
  optional arguments:
79
  -h, --help show this help message and exit
 
73
  To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
74
 
75
  ```
76
+
77
+ $ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
78
+
79
 
80
  optional arguments:
81
  -h, --help show this help message and exit
test.py CHANGED
@@ -8,7 +8,7 @@ from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
  from torchvision.transforms import ToPILImage
11
- from utils import adaptive_instance_normalization, grid_image, transform, Range
12
  from glob import glob
13
 
14
  parser = argparse.ArgumentParser()
@@ -20,6 +20,7 @@ parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='D
20
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
21
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
22
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
 
23
  args = parser.parse_args()
24
  assert args.content_image or args.content_dir
25
  assert args.style_image or args.style_dir
@@ -103,6 +104,10 @@ def main():
103
 
104
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
105
 
 
 
 
 
106
  # Start time
107
  tic = time.perf_counter()
108
 
@@ -117,7 +122,9 @@ def main():
117
  times.append(toc-tic)
118
 
119
  # Save image
120
- out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
 
 
121
  save_image(out_tensor, out_pth)
122
 
123
  if args.grid_pth:
 
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
  from torchvision.transforms import ToPILImage
11
+ from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
12
  from glob import glob
13
 
14
  parser = argparse.ArgumentParser()
 
20
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
21
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
22
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
23
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
24
  args = parser.parse_args()
25
  assert args.content_image or args.content_dir
26
  assert args.style_image or args.style_dir
 
104
 
105
  style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
106
 
107
+ # Linear Histogram Matching if needed
108
+ if args.color_control:
109
+ style_tensor = linear_histogram_matching(content_tensor,style_tensor)
110
+
111
  # Start time
112
  tic = time.perf_counter()
113
 
 
122
  times.append(toc-tic)
123
 
124
  # Save image
125
+ out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
126
+ if args.color_control: out_pth += '_colorcontrol'
127
+ out_pth += content_pth.suffix
128
  save_image(out_tensor, out_pth)
129
 
130
  if args.grid_pth:
test_interpolate.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
- from utils import adaptive_instance_normalization, transform, Range, grid_image
11
  from glob import glob
12
 
13
  parser = argparse.ArgumentParser()
@@ -19,6 +19,7 @@ parser.add_argument('--interpolation_weights', type=str, help='Weights of interp
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
21
  if use grid mode, provide 4 style images')
 
22
  args = parser.parse_args()
23
  assert args.content_image
24
  assert args.style_image
@@ -106,7 +107,13 @@ def main():
106
  style_tensor = []
107
  for style_pth in style_pths:
108
  img = Image.open(style_pth)
109
- style_tensor.append(transform([512, 512])(img))
 
 
 
 
 
 
110
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
111
 
112
  for inter_weight in inter_weights:
@@ -117,7 +124,9 @@ def main():
117
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
118
 
119
  # Save results
120
- out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
 
 
121
  save_image(out_tensor, out_pth)
122
 
123
  if args.grid_pth:
 
7
  from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
+ from utils import adaptive_instance_normalization, transform,linear_histogram_matching, Range, grid_image
11
  from glob import glob
12
 
13
  parser = argparse.ArgumentParser()
 
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
21
  if use grid mode, provide 4 style images')
22
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
23
  args = parser.parse_args()
24
  assert args.content_image
25
  assert args.style_image
 
107
  style_tensor = []
108
  for style_pth in style_pths:
109
  img = Image.open(style_pth)
110
+ if args.color_control:
111
+ img = transform([512,512])(img).unsqueeze(0)
112
+ img = linear_histogram_matching(content_tensor,img)
113
+ img = img.squeeze(0)
114
+ style_tensor.append(img)
115
+ else:
116
+ style_tensor.append(transform([512, 512])(img))
117
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
118
 
119
  for inter_weight in inter_weights:
 
124
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
125
 
126
  # Save results
127
+ out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
128
+ if args.color_control: out_pth += '_colorcontrol'
129
+ out_pth += content_pth.suffix
130
  save_image(out_tensor, out_pth)
131
 
132
  if args.grid_pth:
test_video.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from pathlib import Path
5
  from AdaIN import AdaINNet
6
  from PIL import Image
7
- from utils import transform, adaptive_instance_normalization, Range
8
  import cv2
9
  import imageio
10
  import numpy as np
@@ -17,6 +17,7 @@ parser.add_argument('--style_image', type=str, required=True, help='Style image
17
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
18
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
 
20
  args = parser.parse_args()
21
 
22
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
@@ -67,8 +68,10 @@ def main():
67
  # Prepare output video writer
68
  out_dir = './results_video/'
69
  os.makedirs(out_dir, exist_ok=True)
70
- out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
71
- + style_image_pth.stem + content_video_pth.suffix)
 
 
72
  writer = imageio.get_writer(out_pth, mode='I', fps=fps)
73
 
74
  # Load AdaIN model
@@ -82,6 +85,7 @@ def main():
82
  style_tensor = t(style_image).unsqueeze(0).to(device)
83
 
84
 
 
85
  while content_video.isOpened():
86
  ret, content_image = content_video.read()
87
  # Failed to read a frame
@@ -90,6 +94,10 @@ def main():
90
 
91
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
92
 
 
 
 
 
93
  with torch.no_grad():
94
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
95
  , model.decoder, args.alpha).cpu().detach().numpy()
 
4
  from pathlib import Path
5
  from AdaIN import AdaINNet
6
  from PIL import Image
7
+ from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range
8
  import cv2
9
  import imageio
10
  import numpy as np
 
17
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
18
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
21
  args = parser.parse_args()
22
 
23
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
 
68
  # Prepare output video writer
69
  out_dir = './results_video/'
70
  os.makedirs(out_dir, exist_ok=True)
71
+ out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
72
+ if args.color_control: out_pth += '_colorcontrol'
73
+ out_pth += content_video_pth.suffix
74
+ out_pth = Path(out_pth)
75
  writer = imageio.get_writer(out_pth, mode='I', fps=fps)
76
 
77
  # Load AdaIN model
 
85
  style_tensor = t(style_image).unsqueeze(0).to(device)
86
 
87
 
88
+
89
  while content_video.isOpened():
90
  ret, content_image = content_video.read()
91
  # Failed to read a frame
 
94
 
95
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
96
 
97
+ # Linear Histogram Matching if needed
98
+ if args.color_control:
99
+ style_tensor = linear_histogram_matching(content_tensor,style_tensor)
100
+
101
  with torch.no_grad():
102
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
103
  , model.decoder, args.alpha).cpu().detach().numpy()
utils.py CHANGED
@@ -74,6 +74,33 @@ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
74
  plt.savefig(save_pth)
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class TrainSet(Dataset):
78
  """
79
  Build Training dataset
 
74
  plt.savefig(save_pth)
75
 
76
 
77
def linear_histogram_matching(content_tensor, style_tensor):
    """
    Linearly remap each style channel so that its mean and standard deviation
    match those of the corresponding content channel:

        out = (style - mean_style) * (std_content / std_style) + mean_content

    Statistics are taken per sample and per channel over all remaining
    (spatial) dimensions, so content and style may differ in spatial size.
    Batch dimensions must be equal or broadcastable (e.g. one content image
    against several style images).

    The input ``style_tensor`` is not modified; a new tensor is returned.

    Args:
        content_tensor (torch.FloatTensor): Content image batch, indexed as
            [batch][channel][...].
        style_tensor (torch.FloatTensor): Style image batch, indexed as
            [batch][channel][...].

    Return:
        style_tensor (torch.FloatTensor): histogram matched Style Image, with
            the same device and dtype as the input style tensor.
    """
    mean_ct, std_ct = _channel_mean_std(content_tensor)
    mean_st, std_st = _channel_mean_std(style_tensor)

    # The content statistics may live on another device (e.g. content already
    # moved to CUDA while the style image is still on the CPU); bring them to
    # the style tensor's device/dtype before mixing them.
    mean_ct = mean_ct.to(style_tensor)
    std_ct = std_ct.to(style_tensor)

    # A constant style channel has zero spread; clamp the divisor so that such
    # a channel maps to the content mean instead of producing NaN/inf.
    eps = torch.finfo(style_tensor.dtype).eps
    scale = std_ct / std_st.clamp(min=eps)

    return (style_tensor - mean_st) * scale + mean_ct


def _channel_mean_std(images):
    """
    Return per-sample, per-channel mean and (population) standard deviation of
    ``images``, each shaped [batch, channel, 1, ...] so that they broadcast
    against a tensor with the same number of dimensions as ``images``.
    """
    flat = images.flatten(start_dim=2)
    mean = flat.mean(dim=-1)
    # Standard deviation, not variance: the linear transform scales values,
    # so the ratio must be taken between standard deviations.
    std = flat.std(dim=-1, unbiased=False)
    stat_shape = tuple(mean.shape) + (1,) * (images.dim() - 2)
    return mean.reshape(stat_shape), std.reshape(stat_shape)
102
+
103
+
104
  class TrainSet(Dataset):
105
  """
106
  Build Training dataset