add color control to test_interpolte.py
Browse files- .gitignore +1 -1
- README.md +1 -1
- test_interpolate.py +12 -3
.gitignore
CHANGED
@@ -2,4 +2,4 @@
|
|
2 |
/__pycache__/
|
3 |
|
4 |
#Ignore results
|
5 |
-
/results
|
|
|
2 |
/__pycache__/
|
3 |
|
4 |
#Ignore results
|
5 |
+
/results*/
|
README.md
CHANGED
@@ -73,7 +73,7 @@ optional arguments:
|
|
73 |
To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
|
74 |
|
75 |
```
|
76 |
-
$ python
|
77 |
|
78 |
optional arguments:
|
79 |
-h, --help show this help message and exit
|
|
|
73 |
To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by comma. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. Specify `--grid_pth` to interpolate with different built-in weights and provide 4 style images.
|
74 |
|
75 |
```
|
76 |
+
$ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
|
77 |
|
78 |
optional arguments:
|
79 |
-h, --help show this help message and exit
|
test_interpolate.py
CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
7 |
from AdaIN import AdaINNet
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
-
from utils import adaptive_instance_normalization, transform, Range, grid_image
|
11 |
from glob import glob
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
@@ -19,6 +19,7 @@ parser.add_argument('--interpolation_weights', type=str, help='Weights of interp
|
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
|
21 |
if use grid mode, provide 4 style images')
|
|
|
22 |
args = parser.parse_args()
|
23 |
assert args.content_image
|
24 |
assert args.style_image
|
@@ -106,7 +107,13 @@ def main():
|
|
106 |
style_tensor = []
|
107 |
for style_pth in style_pths:
|
108 |
img = Image.open(style_pth)
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
style_tensor = torch.stack(style_tensor, dim=0).to(device)
|
111 |
|
112 |
for inter_weight in inter_weights:
|
@@ -117,7 +124,9 @@ def main():
|
|
117 |
print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
|
118 |
|
119 |
# Save results
|
120 |
-
out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
|
|
|
|
|
121 |
save_image(out_tensor, out_pth)
|
122 |
|
123 |
if args.grid_pth:
|
|
|
7 |
from AdaIN import AdaINNet
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
+
from utils import adaptive_instance_normalization, transform,linear_histogram_matching, Range, grid_image
|
11 |
from glob import glob
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
|
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
|
21 |
if use grid mode, provide 4 style images')
|
22 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
23 |
args = parser.parse_args()
|
24 |
assert args.content_image
|
25 |
assert args.style_image
|
|
|
107 |
style_tensor = []
|
108 |
for style_pth in style_pths:
|
109 |
img = Image.open(style_pth)
|
110 |
+
if args.color_control:
|
111 |
+
img = transform([512,512])(img).unsqueeze(0)
|
112 |
+
img = linear_histogram_matching(content_tensor,img)
|
113 |
+
img = img.squeeze(0)
|
114 |
+
style_tensor.append(img)
|
115 |
+
else:
|
116 |
+
style_tensor.append(transform([512, 512])(img))
|
117 |
style_tensor = torch.stack(style_tensor, dim=0).to(device)
|
118 |
|
119 |
for inter_weight in inter_weights:
|
|
|
124 |
print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
|
125 |
|
126 |
# Save results
|
127 |
+
out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
|
128 |
+
if args.color_control: out_pth += '_colorcontrol'
|
129 |
+
out_pth += content_pth.suffix
|
130 |
save_image(out_tensor, out_pth)
|
131 |
|
132 |
if args.grid_pth:
|