MasaTate commited on
Commit
f671c93
·
1 Parent(s): b000f15

move linear histogram matching to utils.py

Browse files
Files changed (3) hide show
  1. .gitignore +5 -0
  2. test.py +1 -28
  3. utils.py +27 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #Ignore __pycache__
2
+ /__pycache__/
3
+
4
+ #Ignore results
5
+ /results/
test.py CHANGED
@@ -8,7 +8,7 @@ from AdaIN import AdaINNet
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
  from torchvision.transforms import ToPILImage
11
- from utils import adaptive_instance_normalization, grid_image, transform, Range
12
  from glob import glob
13
 
14
  parser = argparse.ArgumentParser()
@@ -55,33 +55,6 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
55
  return decoder(mix_enc)
56
 
57
 
58
- def linear_histogram_matching(content_tensor, style_tensor):
59
- """
60
- Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
61
-
62
- Args:
63
- content_tensor (torch.FloatTensor): Content image
64
- style_tensor (torch.FloatTensor): Style Image
65
-
66
- Return:
67
- style_tensor (torch.FloatTensor): histogram matched Style Image
68
- """
69
- #for batch
70
- for b in range(len(content_tensor)):
71
- std_ct = []
72
- std_st = []
73
- mean_ct = []
74
- mean_st = []
75
- #for channel
76
- for c in range(len(content_tensor[b])):
77
- std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
78
- mean_ct.append(torch.mean(content_tensor[b][c]))
79
- std_st.append(torch.var(style_tensor[b][c],unbiased = False))
80
- mean_st.append(torch.mean(style_tensor[b][c]))
81
- style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
82
- return style_tensor
83
-
84
-
85
  def main():
86
  # Read content images and style images
87
  if args.content_image:
 
8
  from PIL import Image
9
  from torchvision.utils import save_image
10
  from torchvision.transforms import ToPILImage
11
+ from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
12
  from glob import glob
13
 
14
  parser = argparse.ArgumentParser()
 
55
  return decoder(mix_enc)
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def main():
59
  # Read content images and style images
60
  if args.content_image:
utils.py CHANGED
@@ -74,6 +74,33 @@ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
74
  plt.savefig(save_pth)
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class TrainSet(Dataset):
78
  """
79
  Build Training dataset
 
74
  plt.savefig(save_pth)
75
 
76
 
77
+ def linear_histogram_matching(content_tensor, style_tensor):
78
+ """
79
+ Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
80
+
81
+ Args:
82
+ content_tensor (torch.FloatTensor): Content image
83
+ style_tensor (torch.FloatTensor): Style Image
84
+
85
+ Return:
86
+ style_tensor (torch.FloatTensor): histogram matched Style Image
87
+ """
88
+ #for batch
89
+ for b in range(len(content_tensor)):
90
+ std_ct = []
91
+ std_st = []
92
+ mean_ct = []
93
+ mean_st = []
94
+ #for channel
95
+ for c in range(len(content_tensor[b])):
96
+ std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
97
+ mean_ct.append(torch.mean(content_tensor[b][c]))
98
+ std_st.append(torch.var(style_tensor[b][c],unbiased = False))
99
+ mean_st.append(torch.mean(style_tensor[b][c]))
100
+ style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
101
+ return style_tensor
102
+
103
+
104
  class TrainSet(Dataset):
105
  """
106
  Build Training dataset