MasaTate commited on
Commit
cfa18c8
·
1 Parent(s): 3a1d3f5

add linear histogram matching to test_video.py

Browse files
Files changed (1) hide show
  1. test_video.py +11 -3
test_video.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from pathlib import Path
5
  from AdaIN import AdaINNet
6
  from PIL import Image
7
- from utils import transform, adaptive_instance_normalization, Range
8
  import cv2
9
  import imageio
10
  import numpy as np
@@ -17,6 +17,7 @@ parser.add_argument('--style_image', type=str, required=True, help='Style image
17
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
18
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
 
20
  args = parser.parse_args()
21
 
22
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
@@ -67,8 +68,10 @@ def main():
67
  # Prepare output video writer
68
  out_dir = './results_video/'
69
  os.makedirs(out_dir, exist_ok=True)
70
- out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
71
- + style_image_pth.stem + content_video_pth.suffix)
 
 
72
  writer = imageio.get_writer(out_pth, mode='I', fps=fps)
73
 
74
  # Load AdaIN model
@@ -82,6 +85,7 @@ def main():
82
  style_tensor = t(style_image).unsqueeze(0).to(device)
83
 
84
 
 
85
  while content_video.isOpened():
86
  ret, content_image = content_video.read()
87
  # Failed to read a frame
@@ -90,6 +94,10 @@ def main():
90
 
91
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
92
 
 
 
 
 
93
  with torch.no_grad():
94
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
95
  , model.decoder, args.alpha).cpu().detach().numpy()
 
4
  from pathlib import Path
5
  from AdaIN import AdaINNet
6
  from PIL import Image
7
+ from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range
8
  import cv2
9
  import imageio
10
  import numpy as np
 
17
  parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
18
  parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
19
  parser.add_argument('--cuda', action='store_true', help='Use CUDA')
20
+ parser.add_argument('--color_control', action='store_true', help='Preserve content color')
21
  args = parser.parse_args()
22
 
23
  device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
 
68
  # Prepare output video writer
69
  out_dir = './results_video/'
70
  os.makedirs(out_dir, exist_ok=True)
71
+ out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
72
+ if args.color_control: out_pth += '_colorcontrol'
73
+ out_pth += content_video_pth.suffix
74
+ out_pth = Path(out_pth)
75
  writer = imageio.get_writer(out_pth, mode='I', fps=fps)
76
 
77
  # Load AdaIN model
 
85
  style_tensor = t(style_image).unsqueeze(0).to(device)
86
 
87
 
88
+
89
  while content_video.isOpened():
90
  ret, content_image = content_video.read()
91
  # Failed to read a frame
 
94
 
95
  content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
96
 
97
+ # Linear Histogram Matching if needed
98
+ if args.color_control:
99
+ style_tensor = linear_histogram_matching(content_tensor,style_tensor)
100
+
101
  with torch.no_grad():
102
  out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
103
  , model.decoder, args.alpha).cpu().detach().numpy()