Merge pull request #3 from MasaTate/main
Browse files- .gitignore +5 -0
- README.md +3 -1
- test.py +9 -2
- test_interpolate.py +12 -3
- test_video.py +11 -3
- utils.py +27 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#Ignore __pycache__
|
2 |
+
/__pycache__/
|
3 |
+
|
4 |
+
#Ignore results
|
5 |
+
/results*/
|
README.md
CHANGED
@@ -73,7 +73,9 @@ optional arguments:
|
|
73 |
To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by commas. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. To interpolate with different built-in weights, specify `--grid_pth` and provide 4 style images.
|
74 |
|
75 |
```
|
76 |
-
|
|
|
|
|
77 |
|
78 |
optional arguments:
|
79 |
-h, --help show this help message and exit
|
|
|
73 |
To test style transfer interpolation, run the script test_interpolate.py. Specify `--style_image` with multiple paths separated by commas. Specify `--interpolation_weights` to interpolate once. All outputs are saved in `./results_interpolate/`. To interpolate with different built-in weights, specify `--grid_pth` and provide 4 style images.
|
74 |
|
75 |
```
|
76 |
+
|
77 |
+
$ python test_interpolate.py --content_image $IMG --style_image $STYLE $WEIGHT --cuda
|
78 |
+
|
79 |
|
80 |
optional arguments:
|
81 |
-h, --help show this help message and exit
|
test.py
CHANGED
@@ -8,7 +8,7 @@ from AdaIN import AdaINNet
|
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
from torchvision.transforms import ToPILImage
|
11 |
-
from utils import adaptive_instance_normalization, grid_image, transform, Range
|
12 |
from glob import glob
|
13 |
|
14 |
parser = argparse.ArgumentParser()
|
@@ -20,6 +20,7 @@ parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='D
|
|
20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
22 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
|
|
23 |
args = parser.parse_args()
|
24 |
assert args.content_image or args.content_dir
|
25 |
assert args.style_image or args.style_dir
|
@@ -103,6 +104,10 @@ def main():
|
|
103 |
|
104 |
style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
|
105 |
|
|
|
|
|
|
|
|
|
106 |
# Start time
|
107 |
tic = time.perf_counter()
|
108 |
|
@@ -117,7 +122,9 @@ def main():
|
|
117 |
times.append(toc-tic)
|
118 |
|
119 |
# Save image
|
120 |
-
out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
|
|
|
|
|
121 |
save_image(out_tensor, out_pth)
|
122 |
|
123 |
if args.grid_pth:
|
|
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
from torchvision.transforms import ToPILImage
|
11 |
+
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
|
12 |
from glob import glob
|
13 |
|
14 |
parser = argparse.ArgumentParser()
|
|
|
20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
22 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
23 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
24 |
args = parser.parse_args()
|
25 |
assert args.content_image or args.content_dir
|
26 |
assert args.style_image or args.style_dir
|
|
|
104 |
|
105 |
style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
|
106 |
|
107 |
+
# Linear Histogram Matching if needed
|
108 |
+
if args.color_control:
|
109 |
+
style_tensor = linear_histogram_matching(content_tensor,style_tensor)
|
110 |
+
|
111 |
# Start time
|
112 |
tic = time.perf_counter()
|
113 |
|
|
|
122 |
times.append(toc-tic)
|
123 |
|
124 |
# Save image
|
125 |
+
out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
|
126 |
+
if args.color_control: out_pth += '_colorcontrol'
|
127 |
+
out_pth += content_pth.suffix
|
128 |
save_image(out_tensor, out_pth)
|
129 |
|
130 |
if args.grid_pth:
|
test_interpolate.py
CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
7 |
from AdaIN import AdaINNet
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
-
from utils import adaptive_instance_normalization, transform, Range, grid_image
|
11 |
from glob import glob
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
@@ -19,6 +19,7 @@ parser.add_argument('--interpolation_weights', type=str, help='Weights of interp
|
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
|
21 |
if use grid mode, provide 4 style images')
|
|
|
22 |
args = parser.parse_args()
|
23 |
assert args.content_image
|
24 |
assert args.style_image
|
@@ -106,7 +107,13 @@ def main():
|
|
106 |
style_tensor = []
|
107 |
for style_pth in style_pths:
|
108 |
img = Image.open(style_pth)
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
style_tensor = torch.stack(style_tensor, dim=0).to(device)
|
111 |
|
112 |
for inter_weight in inter_weights:
|
@@ -117,7 +124,9 @@ def main():
|
|
117 |
print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
|
118 |
|
119 |
# Save results
|
120 |
-
out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
|
|
|
|
|
121 |
save_image(out_tensor, out_pth)
|
122 |
|
123 |
if args.grid_pth:
|
|
|
7 |
from AdaIN import AdaINNet
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
+
from utils import adaptive_instance_normalization, transform,linear_histogram_matching, Range, grid_image
|
11 |
from glob import glob
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
|
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images. \
|
21 |
if use grid mode, provide 4 style images')
|
22 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
23 |
args = parser.parse_args()
|
24 |
assert args.content_image
|
25 |
assert args.style_image
|
|
|
107 |
style_tensor = []
|
108 |
for style_pth in style_pths:
|
109 |
img = Image.open(style_pth)
|
110 |
+
if args.color_control:
|
111 |
+
img = transform([512,512])(img).unsqueeze(0)
|
112 |
+
img = linear_histogram_matching(content_tensor,img)
|
113 |
+
img = img.squeeze(0)
|
114 |
+
style_tensor.append(img)
|
115 |
+
else:
|
116 |
+
style_tensor.append(transform([512, 512])(img))
|
117 |
style_tensor = torch.stack(style_tensor, dim=0).to(device)
|
118 |
|
119 |
for inter_weight in inter_weights:
|
|
|
124 |
print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
|
125 |
|
126 |
# Save results
|
127 |
+
out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
|
128 |
+
if args.color_control: out_pth += '_colorcontrol'
|
129 |
+
out_pth += content_pth.suffix
|
130 |
save_image(out_tensor, out_pth)
|
131 |
|
132 |
if args.grid_pth:
|
test_video.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
from pathlib import Path
|
5 |
from AdaIN import AdaINNet
|
6 |
from PIL import Image
|
7 |
-
from utils import transform, adaptive_instance_normalization, Range
|
8 |
import cv2
|
9 |
import imageio
|
10 |
import numpy as np
|
@@ -17,6 +17,7 @@ parser.add_argument('--style_image', type=str, required=True, help='Style image
|
|
17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
20 |
args = parser.parse_args()
|
21 |
|
22 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
@@ -67,8 +68,10 @@ def main():
|
|
67 |
# Prepare output video writer
|
68 |
out_dir = './results_video/'
|
69 |
os.makedirs(out_dir, exist_ok=True)
|
70 |
-
out_pth =
|
71 |
-
|
|
|
|
|
72 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
73 |
|
74 |
# Load AdaIN model
|
@@ -82,6 +85,7 @@ def main():
|
|
82 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
83 |
|
84 |
|
|
|
85 |
while content_video.isOpened():
|
86 |
ret, content_image = content_video.read()
|
87 |
# Failed to read a frame
|
@@ -90,6 +94,10 @@ def main():
|
|
90 |
|
91 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
92 |
|
|
|
|
|
|
|
|
|
93 |
with torch.no_grad():
|
94 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
95 |
, model.decoder, args.alpha).cpu().detach().numpy()
|
|
|
4 |
from pathlib import Path
|
5 |
from AdaIN import AdaINNet
|
6 |
from PIL import Image
|
7 |
+
from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range
|
8 |
import cv2
|
9 |
import imageio
|
10 |
import numpy as np
|
|
|
17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
21 |
args = parser.parse_args()
|
22 |
|
23 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
|
|
68 |
# Prepare output video writer
|
69 |
out_dir = './results_video/'
|
70 |
os.makedirs(out_dir, exist_ok=True)
|
71 |
+
out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
|
72 |
+
if args.color_control: out_pth += '_colorcontrol'
|
73 |
+
out_pth += content_video_pth.suffix
|
74 |
+
out_pth = Path(out_pth)
|
75 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
76 |
|
77 |
# Load AdaIN model
|
|
|
85 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
86 |
|
87 |
|
88 |
+
|
89 |
while content_video.isOpened():
|
90 |
ret, content_image = content_video.read()
|
91 |
# Failed to read a frame
|
|
|
94 |
|
95 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
96 |
|
97 |
+
# Linear Histogram Matching if needed
|
98 |
+
if args.color_control:
|
99 |
+
style_tensor = linear_histogram_matching(content_tensor,style_tensor)
|
100 |
+
|
101 |
with torch.no_grad():
|
102 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
103 |
, model.decoder, args.alpha).cpu().detach().numpy()
|
utils.py
CHANGED
@@ -74,6 +74,33 @@ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
|
|
74 |
plt.savefig(save_pth)
|
75 |
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
class TrainSet(Dataset):
|
78 |
"""
|
79 |
Build Training dataset
|
|
|
74 |
plt.savefig(save_pth)
|
75 |
|
76 |
|
77 |
+
def linear_histogram_matching(content_tensor, style_tensor, eps=1e-8):
    """
    Given content_tensor and style_tensor, linearly transform each channel of
    style_tensor so that its mean and standard deviation match those of the
    corresponding channel of content_tensor.

    For every batch entry b and channel c the result is
        (style - mean_style) * std_content / std_style + mean_content
    with all statistics taken over the remaining (spatial) dimensions.

    Args:
        content_tensor (torch.FloatTensor): Content image, indexed as
            [batch][channel][...]. Must have at least 3 dimensions.
        style_tensor (torch.FloatTensor): Style image, indexed the same way.
            Only the first len(content_tensor) batch entries are matched;
            any further entries are returned unchanged.
        eps (float): Small constant added to the style standard deviation so
            that a constant style channel does not cause a division by zero.

    Return:
        style_tensor (torch.FloatTensor): histogram matched Style Image. This
            is a new tensor; the style_tensor argument is not modified.
    """
    # Work on a copy so the caller's tensor is never mutated in place.
    matched = style_tensor.clone()

    batch = len(content_tensor)
    style = matched[:batch]

    # Reduce over everything after (batch, channel) to get per-channel stats.
    content_dims = tuple(range(2, content_tensor.dim()))
    style_dims = tuple(range(2, style.dim()))

    mean_ct = content_tensor.mean(dim=content_dims, keepdim=True)
    # Standard deviation (sqrt of the population variance), not the variance
    # itself: scaling by a ratio of variances would not match the spread.
    std_ct = torch.var(content_tensor, dim=content_dims, unbiased=False, keepdim=True).sqrt()
    mean_st = style.mean(dim=style_dims, keepdim=True)
    std_st = torch.var(style, dim=style_dims, unbiased=False, keepdim=True).sqrt()

    # Content and style may live on different devices (e.g. content already
    # moved to the GPU while style is still on the CPU).
    mean_ct = mean_ct.to(style.device)
    std_ct = std_ct.to(style.device)

    matched[:batch] = (style - mean_st) * (std_ct / (std_st + eps)) + mean_ct
    return matched
|
102 |
+
|
103 |
+
|
104 |
class TrainSet(Dataset):
|
105 |
"""
|
106 |
Build Training dataset
|