MasaTate commited on
Commit
3a1d3f5
·
1 Parent(s): f671c93

add color control to test_interpolte.py

Browse files
Files changed (3) hide show
  1. .gitignore +1 -1
  2. README.md +1 -1
  3. test_interpolate.py +12 -3
.gitignore CHANGED
@@ -2,4 +2,4 @@
2
  /__pycache__/
3
 
4
  #Ignore results
5
- /results/
 
2
  /__pycache__/
3
 
4
  #Ignore results
5
+ /results*/
README.md CHANGED
@@ -73,7 +73,7 @@ optional arguments:
73
  To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
74
 
75
  ```
76
- $ python test_interpolation.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
77
 
78
  optional arguments:
79
  -h, --help show this help message and exit
 
73
  To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
74
 
75
  ```
76
+ $ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
77
 
78
  optional arguments:
79
  -h, --help show this help message and exit
test_interpolate.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
- from utils import adaptive_instance_normalization, transform, Range, grid_image
11
  from glob import glob
12
 
13
  parser = argparse.ArgumentParser()
@@ -19,6 +19,7 @@ parser.add_argument('--interpolation_weights', type=str, help='Weights of interp
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
21
  if use grid mode, provide 4 style images')
 
22
  args = parser.parse_args()
23
  assert args.content_image
24
  assert args.style_image
@@ -106,7 +107,13 @@ def main():
106
  style_tensor = []
107
  for style_pth in style_pths:
108
  img = Image.open(style_pth)
109
- style_tensor.append(transform([512, 512])(img))
 
 
 
 
 
 
110
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
111
 
112
  for inter_weight in inter_weights:
@@ -117,7 +124,9 @@ def main():
117
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
118
 
119
  # Save results
120
- out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
 
 
121
  save_image(out_tensor, out_pth)
122
 
123
  if args.grid_pth:
 
7
  from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
+ from utils import adaptive_instance_normalization, transform,linear_histogram_matching, Range, grid_image
11
  from glob import glob
12
 
13
  parser = argparse.ArgumentParser()
 
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
  parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
21
  if use grid mode, provide 4 style images')
22
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
23
  args = parser.parse_args()
24
  assert args.content_image
25
  assert args.style_image
 
107
  style_tensor = []
108
  for style_pth in style_pths:
109
  img = Image.open(style_pth)
110
+ if args.color_control:
111
+ img = transform([512,512])(img).unsqueeze(0)
112
+ img = linear_histogram_matching(content_tensor,img)
113
+ img = img.squeeze(0)
114
+ style_tensor.append(img)
115
+ else:
116
+ style_tensor.append(transform([512, 512])(img))
117
  style_tensor = torch.stack(style_tensor, dim=0).to(device)
118
 
119
  for inter_weight in inter_weights:
 
124
  print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
125
 
126
  # Save results
127
+ out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
128
+ if args.color_control: out_pth += '_colorcontrol'
129
+ out_pth += content_pth.suffix
130
  save_image(out_tensor, out_pth)
131
 
132
  if args.grid_pth: