add linear histogram matching to test_video.py
Browse files- test_video.py +11 -3
test_video.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
from pathlib import Path
|
5 |
from AdaIN import AdaINNet
|
6 |
from PIL import Image
|
7 |
-
from utils import transform, adaptive_instance_normalization, Range
|
8 |
import cv2
|
9 |
import imageio
|
10 |
import numpy as np
|
@@ -17,6 +17,7 @@ parser.add_argument('--style_image', type=str, required=True, help='Style image
|
|
17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
20 |
args = parser.parse_args()
|
21 |
|
22 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
@@ -67,8 +68,10 @@ def main():
|
|
67 |
# Prepare output video writer
|
68 |
out_dir = './results_video/'
|
69 |
os.makedirs(out_dir, exist_ok=True)
|
70 |
-
out_pth =
|
71 |
-
|
|
|
|
|
72 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
73 |
|
74 |
# Load AdaIN model
|
@@ -82,6 +85,7 @@ def main():
|
|
82 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
83 |
|
84 |
|
|
|
85 |
while content_video.isOpened():
|
86 |
ret, content_image = content_video.read()
|
87 |
# Failed to read a frame
|
@@ -90,6 +94,10 @@ def main():
|
|
90 |
|
91 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
92 |
|
|
|
|
|
|
|
|
|
93 |
with torch.no_grad():
|
94 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
95 |
, model.decoder, args.alpha).cpu().detach().numpy()
|
|
|
4 |
from pathlib import Path
|
5 |
from AdaIN import AdaINNet
|
6 |
from PIL import Image
|
7 |
+
from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range
|
8 |
import cv2
|
9 |
import imageio
|
10 |
import numpy as np
|
|
|
17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
20 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
21 |
args = parser.parse_args()
|
22 |
|
23 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
|
|
68 |
# Prepare output video writer
|
69 |
out_dir = './results_video/'
|
70 |
os.makedirs(out_dir, exist_ok=True)
|
71 |
+
out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
|
72 |
+
if args.color_control: out_pth += '_colorcontrol'
|
73 |
+
out_pth += content_video_pth.suffix
|
74 |
+
out_pth = Path(out_pth)
|
75 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
76 |
|
77 |
# Load AdaIN model
|
|
|
85 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
86 |
|
87 |
|
88 |
+
|
89 |
while content_video.isOpened():
|
90 |
ret, content_image = content_video.read()
|
91 |
# Failed to read a frame
|
|
|
94 |
|
95 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
96 |
|
97 |
+
# Linear Histogram Matching if needed
|
98 |
+
if args.color_control:
|
99 |
+
style_tensor = linear_histogram_matching(content_tensor,style_tensor)
|
100 |
+
|
101 |
with torch.no_grad():
|
102 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
103 |
, model.decoder, args.alpha).cpu().detach().numpy()
|