move linear histogram matching to utils.py
Browse files- .gitignore +5 -0
- test.py +1 -28
- utils.py +27 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#Ignore __pycache__
|
2 |
+
/__pycache__/
|
3 |
+
|
4 |
+
#Ignore results
|
5 |
+
/results/
|
test.py
CHANGED
@@ -8,7 +8,7 @@ from AdaIN import AdaINNet
|
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
from torchvision.transforms import ToPILImage
|
11 |
-
from utils import adaptive_instance_normalization, grid_image, transform, Range
|
12 |
from glob import glob
|
13 |
|
14 |
parser = argparse.ArgumentParser()
|
@@ -55,33 +55,6 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
|
|
55 |
return decoder(mix_enc)
|
56 |
|
57 |
|
58 |
-
def linear_histogram_matching(content_tensor, style_tensor):
|
59 |
-
"""
|
60 |
-
Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
content_tensor (torch.FloatTensor): Content image
|
64 |
-
style_tensor (torch.FloatTensor): Style Image
|
65 |
-
|
66 |
-
Return:
|
67 |
-
style_tensor (torch.FloatTensor): histogram matched Style Image
|
68 |
-
"""
|
69 |
-
#for batch
|
70 |
-
for b in range(len(content_tensor)):
|
71 |
-
std_ct = []
|
72 |
-
std_st = []
|
73 |
-
mean_ct = []
|
74 |
-
mean_st = []
|
75 |
-
#for channel
|
76 |
-
for c in range(len(content_tensor[b])):
|
77 |
-
std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
|
78 |
-
mean_ct.append(torch.mean(content_tensor[b][c]))
|
79 |
-
std_st.append(torch.var(style_tensor[b][c],unbiased = False))
|
80 |
-
mean_st.append(torch.mean(style_tensor[b][c]))
|
81 |
-
style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
|
82 |
-
return style_tensor
|
83 |
-
|
84 |
-
|
85 |
def main():
|
86 |
# Read content images and style images
|
87 |
if args.content_image:
|
|
|
8 |
from PIL import Image
|
9 |
from torchvision.utils import save_image
|
10 |
from torchvision.transforms import ToPILImage
|
11 |
+
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
|
12 |
from glob import glob
|
13 |
|
14 |
parser = argparse.ArgumentParser()
|
|
|
55 |
return decoder(mix_enc)
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def main():
|
59 |
# Read content images and style images
|
60 |
if args.content_image:
|
utils.py
CHANGED
@@ -74,6 +74,33 @@ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
|
|
74 |
plt.savefig(save_pth)
|
75 |
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
class TrainSet(Dataset):
|
78 |
"""
|
79 |
Build Training dataset
|
|
|
74 |
plt.savefig(save_pth)
|
75 |
|
76 |
|
77 |
+
def linear_histogram_matching(content_tensor, style_tensor):
|
78 |
+
"""
|
79 |
+
Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
content_tensor (torch.FloatTensor): Content image
|
83 |
+
style_tensor (torch.FloatTensor): Style Image
|
84 |
+
|
85 |
+
Return:
|
86 |
+
style_tensor (torch.FloatTensor): histogram matched Style Image
|
87 |
+
"""
|
88 |
+
#for batch
|
89 |
+
for b in range(len(content_tensor)):
|
90 |
+
std_ct = []
|
91 |
+
std_st = []
|
92 |
+
mean_ct = []
|
93 |
+
mean_st = []
|
94 |
+
#for channel
|
95 |
+
for c in range(len(content_tensor[b])):
|
96 |
+
std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
|
97 |
+
mean_ct.append(torch.mean(content_tensor[b][c]))
|
98 |
+
std_st.append(torch.var(style_tensor[b][c],unbiased = False))
|
99 |
+
mean_st.append(torch.mean(style_tensor[b][c]))
|
100 |
+
style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
|
101 |
+
return style_tensor
|
102 |
+
|
103 |
+
|
104 |
class TrainSet(Dataset):
|
105 |
"""
|
106 |
Build Training dataset
|