Dat Nguyen-Tien committed on
Commit a38a6fb · 1 Parent(s): 476c59d

upload_huggff

Files changed (42)
  1. __init__.py +1 -0
  2. __pycache__/__init__.cpython-310.pyc +0 -0
  3. __pycache__/function.cpython-310.pyc +0 -0
  4. __pycache__/sampler.cpython-310.pyc +0 -0
  5. experiments/decoder_iter_160000.pth +3 -0
  6. experiments/embedding_iter_160000.pth +3 -0
  7. experiments/transformer_iter_160000.pth +3 -0
  8. experiments/vgg_normalised.pth +3 -0
  9. function.py +73 -0
  10. images/images_image1.jpg +0 -0
  11. logs/events.out.tfevents.1717260856.NGUYENTIENDAT +3 -0
  12. logs/events.out.tfevents.1717261358.NGUYENTIENDAT +3 -0
  13. logs/events.out.tfevents.1717261528.NGUYENTIENDAT +3 -0
  14. logs/events.out.tfevents.1717262282.NGUYENTIENDAT +3 -0
  15. logs/events.out.tfevents.1717262831.NGUYENTIENDAT +3 -0
  16. logs/events.out.tfevents.1717262870.NGUYENTIENDAT +3 -0
  17. logs/events.out.tfevents.1717300182.NGUYENTIENDAT +3 -0
  18. logs/events.out.tfevents.1717300226.NGUYENTIENDAT +3 -0
  19. logs/events.out.tfevents.1717300229.NGUYENTIENDAT +3 -0
  20. logs/events.out.tfevents.1717300284.NGUYENTIENDAT +3 -0
  21. logs/events.out.tfevents.1717300608.NGUYENTIENDAT +3 -0
  22. logs/events.out.tfevents.1717300611.NGUYENTIENDAT +3 -0
  23. models/StyTR.py +251 -0
  24. models/ViT_helper.py +117 -0
  25. models/__pycache__/StyTR.cpython-310.pyc +0 -0
  26. models/__pycache__/ViT_helper.cpython-310.pyc +0 -0
  27. models/__pycache__/__init__.cpython-310.pyc +0 -0
  28. models/__pycache__/transformer.cpython-310.pyc +0 -0
  29. models/sampler.py +26 -0
  30. models/transformer.py +322 -0
  31. outputs/test/0.jpg +0 -0
  32. sampler.py +26 -0
  33. style/style_image2.jpg +0 -0
  34. test.py +183 -0
  35. train.py +210 -0
  36. util/__init__.py +1 -0
  37. util/__pycache__/__init__.cpython-310.pyc +0 -0
  38. util/__pycache__/box_ops.cpython-310.pyc +0 -0
  39. util/__pycache__/misc.cpython-310.pyc +0 -0
  40. util/box_ops.py +88 -0
  41. util/misc.py +468 -0
  42. util/plot_utils.py +107 -0
__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (142 Bytes).
 
__pycache__/function.cpython-310.pyc ADDED
Binary file (2.17 kB).
 
__pycache__/sampler.cpython-310.pyc ADDED
Binary file (1.12 kB).
 
experiments/decoder_iter_160000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57133e62081041cfc8c16d921071832837844f69a01b1705c7511faa2d2b9eee
3
+ size 14027089
experiments/embedding_iter_160000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f56cb10d331e980654469c80715da5d9eac5e8455f8ed1df41b7141e0612d53a
3
+ size 396481
experiments/transformer_iter_160000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45316082fd7359864cb08bac63b7c8f278f881763ae6fd97416610c2cdca4e67
3
+ size 127208897
experiments/vgg_normalised.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:804ca2835ecf7539f0cd2a7ac3c18ce81e6f8468969ae7117ac0c148d286bb4a
3
+ size 80102481
function.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+
3
+
4
+ def calc_mean_std(feat, eps=1e-5):
5
+ # eps is a small value added to the variance to avoid divide-by-zero.
6
+ size = feat.size()
7
+ assert (len(size) == 4)
8
+ N, C = size[:2]
9
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
10
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
11
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
12
+ return feat_mean, feat_std
13
+
14
+ def calc_mean_std1(feat, eps=1e-5):
15
+ # eps is a small value added to the variance to avoid divide-by-zero.
16
+ size = feat.size()
17
+ # assert (len(size) == 4)
18
+ WH,N, C = size
19
+ feat_var = feat.var(dim=0) + eps
20
+ feat_std = feat_var.sqrt()
21
+ feat_mean = feat.mean(dim=0)
22
+ return feat_mean, feat_std
23
+ def normal(feat, eps=1e-5):
24
+ feat_mean, feat_std= calc_mean_std(feat, eps)
25
+ normalized=(feat-feat_mean)/feat_std
26
+ return normalized
27
+ def normal_style(feat, eps=1e-5):
28
+ feat_mean, feat_std= calc_mean_std1(feat, eps)
29
+ normalized=(feat-feat_mean)/feat_std
30
+ return normalized
31
+
32
+ def _calc_feat_flatten_mean_std(feat):
33
+ # takes a 3D feat (C, H, W); returns the flattened feat and its per-channel mean and std
34
+ assert (feat.size()[0] == 3)
35
+ assert (isinstance(feat, torch.FloatTensor))
36
+ feat_flatten = feat.view(3, -1)
37
+ mean = feat_flatten.mean(dim=-1, keepdim=True)
38
+ std = feat_flatten.std(dim=-1, keepdim=True)
39
+ return feat_flatten, mean, std
40
+
41
+
42
+ def _mat_sqrt(x):
43
+ U, D, V = torch.svd(x)
44
+ return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
45
+
46
+
47
+ def coral(source, target):
48
+ # assume both source and target are 3D array (C, H, W)
49
+ # Note: flatten -> f
50
+
51
+ source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
52
+ source_f_norm = (source_f - source_f_mean.expand_as(
53
+ source_f)) / source_f_std.expand_as(source_f)
54
+ source_f_cov_eye = \
55
+ torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
56
+
57
+ target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
58
+ target_f_norm = (target_f - target_f_mean.expand_as(
59
+ target_f)) / target_f_std.expand_as(target_f)
60
+ target_f_cov_eye = \
61
+ torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
62
+
63
+ source_f_norm_transfer = torch.mm(
64
+ _mat_sqrt(target_f_cov_eye),
65
+ torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
66
+ source_f_norm)
67
+ )
68
+
69
+ source_f_transfer = source_f_norm_transfer * \
70
+ target_f_std.expand_as(source_f_norm) + \
71
+ target_f_mean.expand_as(source_f_norm)
72
+
73
+ return source_f_transfer.view(source.size())
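A minimal usage sketch of these helpers (the tensor shapes below are illustrative assumptions, not values taken from this repository):

import torch
from function import calc_mean_std, normal, coral

feats = torch.randn(2, 512, 32, 32)      # N x C x H x W feature map
mean, std = calc_mean_std(feats)          # each has shape (2, 512, 1, 1)
whitened = normal(feats)                  # channel-wise normalized features

content = torch.rand(3, 256, 256)         # 3 x H x W image tensors
style = torch.rand(3, 256, 256)
recolored = coral(content, style)         # content re-colored with the style's channel statistics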
images/images_image1.jpg ADDED
logs/events.out.tfevents.1717260856.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:561832434af350a6f0e0460f79365352c81825418a2f61f43edf7c5224c63c18
3
+ size 40
logs/events.out.tfevents.1717261358.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f8684900a20b584701c30b530b9275b9918bf2f040446e95aee1bf7ff0a505b
3
+ size 40
logs/events.out.tfevents.1717261528.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab1d65bcd8f36cf6678e8c25ea6ff3fa07db539ca9a0b59c2437bf020b9e5502
3
+ size 40
logs/events.out.tfevents.1717262282.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a0cc507c6909cfb2822bc3be0832c958efbd42f6f07317f1a613118593123ed
3
+ size 40
logs/events.out.tfevents.1717262831.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:727528a0119df4f847c66c0acc34d64f96a57a9bb8dda57bafc4a68070157685
3
+ size 40
logs/events.out.tfevents.1717262870.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2042827551df19eb1bd516c6e0a31acea34ae0c427d7f233064897db9d7de7c2
3
+ size 40
logs/events.out.tfevents.1717300182.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac2eda70720342dc7dde77eff301a8c653df2a2ad9f85d8311bb9112d9cbff0b
3
+ size 40
logs/events.out.tfevents.1717300226.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa75376b2ac29e81ef7a15af8d07ef6921dfbc07107f24d98daf195d675249d2
3
+ size 40
logs/events.out.tfevents.1717300229.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4226559cd8c53df827cbeb02f96f80d60058d1f7cb89f8750462901b2676924
3
+ size 40
logs/events.out.tfevents.1717300284.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f17141f226ef2ba638e4714de753ecae20ce00b091f5f1c784ccb1247464264f
3
+ size 790
logs/events.out.tfevents.1717300608.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba8bbc1a99b7c86c3fad38911ee90e762ab94cfd6d3601dabc2d3472d002fe81
3
+ size 40
logs/events.out.tfevents.1717300611.NGUYENTIENDAT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5797aba72e948834e99c14dd5c467030a7ba4cd979b8abf381132391da8e946c
3
+ size 40
models/StyTR.py ADDED
@@ -0,0 +1,251 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ import numpy as np
5
+ from util import box_ops
6
+ from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
7
+ accuracy, get_world_size, interpolate,
8
+ is_dist_avail_and_initialized)
9
+ from function import normal,normal_style
10
+ from function import calc_mean_std
11
+ import scipy.stats as stats
12
+ from models.ViT_helper import DropPath, to_2tuple, trunc_normal_
13
+
14
+ class PatchEmbed(nn.Module):
15
+ """ Image to Patch Embedding
16
+ """
17
+ def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
18
+ super().__init__()
19
+ img_size = to_2tuple(img_size)
20
+ patch_size = to_2tuple(patch_size)
21
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
22
+ self.img_size = img_size
23
+ self.patch_size = patch_size
24
+ self.num_patches = num_patches
25
+
26
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
27
+ self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
28
+
29
+ def forward(self, x):
30
+ B, C, H, W = x.shape
31
+ print(f"PatchEmbed Input: {x.shape}")
32
+ x = self.proj(x)
33
+ print(f"PatchEmbed Output: {x.shape}")
34
+
35
+ return x
36
+
37
+
38
+ decoder = nn.Sequential(
39
+ nn.ReflectionPad2d((1, 1, 1, 1)),
40
+ nn.Conv2d(512, 256, (3, 3)),
41
+ nn.ReLU(),
42
+ nn.Upsample(scale_factor=2, mode='nearest'),
43
+ nn.ReflectionPad2d((1, 1, 1, 1)),
44
+ nn.Conv2d(256, 256, (3, 3)),
45
+ nn.ReLU(),
46
+ nn.ReflectionPad2d((1, 1, 1, 1)),
47
+ nn.Conv2d(256, 256, (3, 3)),
48
+ nn.ReLU(),
49
+ nn.ReflectionPad2d((1, 1, 1, 1)),
50
+ nn.Conv2d(256, 256, (3, 3)),
51
+ nn.ReLU(),
52
+ nn.ReflectionPad2d((1, 1, 1, 1)),
53
+ nn.Conv2d(256, 128, (3, 3)),
54
+ nn.ReLU(),
55
+ nn.Upsample(scale_factor=2, mode='nearest'),
56
+ nn.ReflectionPad2d((1, 1, 1, 1)),
57
+ nn.Conv2d(128, 128, (3, 3)),
58
+ nn.ReLU(),
59
+ nn.ReflectionPad2d((1, 1, 1, 1)),
60
+ nn.Conv2d(128, 64, (3, 3)),
61
+ nn.ReLU(),
62
+ nn.Upsample(scale_factor=2, mode='nearest'),
63
+ nn.ReflectionPad2d((1, 1, 1, 1)),
64
+ nn.Conv2d(64, 64, (3, 3)),
65
+ nn.ReLU(),
66
+ nn.ReflectionPad2d((1, 1, 1, 1)),
67
+ nn.Conv2d(64, 3, (3, 3)),
68
+ )
69
+ ################# Print the input/output shapes of each decoder layer
70
+ for name, module in decoder.named_children():
71
+ def hook(module, input, output):
72
+ print(f"{module.__class__.__name__} Input: {input[0].shape}")
73
+ print(f"{module.__class__.__name__} Output: {output.shape}")
74
+ module.register_forward_hook(hook)
75
+
76
+
77
+ vgg = nn.Sequential(
78
+ nn.Conv2d(3, 3, (1, 1)),
79
+ nn.ReflectionPad2d((1, 1, 1, 1)),
80
+ nn.Conv2d(3, 64, (3, 3)),
81
+ nn.ReLU(), # relu1-1
82
+ nn.ReflectionPad2d((1, 1, 1, 1)),
83
+ nn.Conv2d(64, 64, (3, 3)),
84
+ nn.ReLU(), # relu1-2
85
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
86
+ nn.ReflectionPad2d((1, 1, 1, 1)),
87
+ nn.Conv2d(64, 128, (3, 3)),
88
+ nn.ReLU(), # relu2-1
89
+ nn.ReflectionPad2d((1, 1, 1, 1)),
90
+ nn.Conv2d(128, 128, (3, 3)),
91
+ nn.ReLU(), # relu2-2
92
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
93
+ nn.ReflectionPad2d((1, 1, 1, 1)),
94
+ nn.Conv2d(128, 256, (3, 3)),
95
+ nn.ReLU(), # relu3-1
96
+ nn.ReflectionPad2d((1, 1, 1, 1)),
97
+ nn.Conv2d(256, 256, (3, 3)),
98
+ nn.ReLU(), # relu3-2
99
+ nn.ReflectionPad2d((1, 1, 1, 1)),
100
+ nn.Conv2d(256, 256, (3, 3)),
101
+ nn.ReLU(), # relu3-3
102
+ nn.ReflectionPad2d((1, 1, 1, 1)),
103
+ nn.Conv2d(256, 256, (3, 3)),
104
+ nn.ReLU(), # relu3-4
105
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
106
+ nn.ReflectionPad2d((1, 1, 1, 1)),
107
+ nn.Conv2d(256, 512, (3, 3)),
108
+ nn.ReLU(), # relu4-1, this is the last layer used
109
+ nn.ReflectionPad2d((1, 1, 1, 1)),
110
+ nn.Conv2d(512, 512, (3, 3)),
111
+ nn.ReLU(), # relu4-2
112
+ nn.ReflectionPad2d((1, 1, 1, 1)),
113
+ nn.Conv2d(512, 512, (3, 3)),
114
+ nn.ReLU(), # relu4-3
115
+ nn.ReflectionPad2d((1, 1, 1, 1)),
116
+ nn.Conv2d(512, 512, (3, 3)),
117
+ nn.ReLU(), # relu4-4
118
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
119
+ nn.ReflectionPad2d((1, 1, 1, 1)),
120
+ nn.Conv2d(512, 512, (3, 3)),
121
+ nn.ReLU(), # relu5-1
122
+ nn.ReflectionPad2d((1, 1, 1, 1)),
123
+ nn.Conv2d(512, 512, (3, 3)),
124
+ nn.ReLU(), # relu5-2
125
+ nn.ReflectionPad2d((1, 1, 1, 1)),
126
+ nn.Conv2d(512, 512, (3, 3)),
127
+ nn.ReLU(), # relu5-3
128
+ nn.ReflectionPad2d((1, 1, 1, 1)),
129
+ nn.Conv2d(512, 512, (3, 3)),
130
+ nn.ReLU() # relu5-4
131
+ )
132
+
133
+ ################# Print the input/output shapes of each VGG layer
134
+ for name, module in vgg.named_children():
135
+ def hook(module, input, output):
136
+ print(f"{module.__class__.__name__} Input: {input[0].shape}")
137
+ print(f"{module.__class__.__name__} Output: {output.shape}")
138
+ module.register_forward_hook(hook)
139
+
140
+ class MLP(nn.Module):
141
+ """ Very simple multi-layer perceptron (also called FFN)"""
142
+
143
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
144
+ super().__init__()
145
+ self.num_layers = num_layers
146
+ h = [hidden_dim] * (num_layers - 1)
147
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
148
+
149
+ def forward(self, x):
150
+ for i, layer in enumerate(self.layers):
151
+ print(f"MLP Layer {i} Input: {x.shape}")
152
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
153
+ print(f"MLP Layer {i} Output: {x.shape}")
154
+ return x
155
+ class StyTrans(nn.Module):
156
+ """ This is the style transform transformer module """
157
+
158
+ def __init__(self,encoder,decoder,PatchEmbed, transformer,args):
159
+
160
+ super().__init__()
161
+ enc_layers = list(encoder.children())
162
+ self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1
163
+ self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1
164
+ self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
165
+ self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
166
+ self.enc_5 = nn.Sequential(*enc_layers[31:44]) # relu4_1 -> relu5_1
167
+
168
+ for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
169
+ for param in getattr(self, name).parameters():
170
+ param.requires_grad = False
171
+
172
+ self.mse_loss = nn.MSELoss()
173
+ self.transformer = transformer
174
+ hidden_dim = transformer.d_model
175
+ self.decode = decoder
176
+ self.embedding = PatchEmbed
177
+
178
+ def encode_with_intermediate(self, input):
179
+ results = [input]
180
+ for i in range(5):
181
+ func = getattr(self, 'enc_{:d}'.format(i + 1))
182
+ results.append(func(results[-1]))
183
+ return results[1:]
184
+
185
+ def calc_content_loss(self, input, target):
186
+ assert (input.size() == target.size())
187
+ assert (target.requires_grad is False)
188
+ return self.mse_loss(input, target)
189
+
190
+ def calc_style_loss(self, input, target):
191
+ assert (input.size() == target.size())
192
+ assert (target.requires_grad is False)
193
+ input_mean, input_std = calc_mean_std(input)
194
+ target_mean, target_std = calc_mean_std(target)
195
+ return self.mse_loss(input_mean, target_mean) + \
196
+ self.mse_loss(input_std, target_std)
197
+ def forward(self, samples_c: NestedTensor,samples_s: NestedTensor):
198
+ """ The forward expects a NestedTensor, which consists of:
199
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
200
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
201
+
202
+ """
203
+ content_input = samples_c
204
+ style_input = samples_s
205
+ if isinstance(samples_c, (list, torch.Tensor)):
206
+ samples_c = nested_tensor_from_tensor_list(samples_c)  # supports different-sized images; padding is recorded in the mask -> [tensor, mask]
207
+ if isinstance(samples_s, (list, torch.Tensor)):
208
+ samples_s = nested_tensor_from_tensor_list(samples_s)
209
+
210
+ # ### features used to calculate the losses
211
+ content_feats = self.encode_with_intermediate(samples_c.tensors)
212
+ style_feats = self.encode_with_intermediate(samples_s.tensors)
213
+
214
+ ### Linear projection
215
+ print(f"Embedding Content Input: {samples_c.tensors.shape}")
216
+ style = self.embedding(samples_s.tensors)
217
+ print(f"Style Output: {style.shape}")
218
+ content = self.embedding(samples_c.tensors)
219
+ print(f"Embedding Content Output: {content.shape}")
220
+
221
+ # positional embedding is calculated in transformer.py
222
+ pos_s = None
223
+ pos_c = None
224
+
225
+ mask = None
226
+ hs = self.transformer(style, mask , content, pos_c, pos_s)
227
+ Ics = self.decode(hs)
228
+
229
+ Ics_feats = self.encode_with_intermediate(Ics)
230
+ loss_c = self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1]))+self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
231
+ # Style loss
232
+ loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
233
+ for i in range(1, 5):
234
+ loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])
235
+
236
+
237
+ Icc = self.decode(self.transformer(content, mask , content, pos_c, pos_c))
238
+ Iss = self.decode(self.transformer(style, mask , style, pos_s, pos_s))
239
+
240
+ #Identity losses lambda 1
241
+ loss_lambda1 = self.calc_content_loss(Icc,content_input)+self.calc_content_loss(Iss,style_input)
242
+
243
+ #Identity losses lambda 2
244
+ Icc_feats=self.encode_with_intermediate(Icc)
245
+ Iss_feats=self.encode_with_intermediate(Iss)
246
+ loss_lambda2 = self.calc_content_loss(Icc_feats[0], content_feats[0])+self.calc_content_loss(Iss_feats[0], style_feats[0])
247
+ for i in range(1, 5):
248
+ loss_lambda2 += self.calc_content_loss(Icc_feats[i], content_feats[i])+self.calc_content_loss(Iss_feats[i], style_feats[i])
249
+ # Select one of the following two return statements and comment out the other
250
+ return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2 #train
251
+ # return Ics #test
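A minimal sketch of how these components are wired together, mirroring test.py from this commit; the 256x256 inputs are an assumption, the VGG weights are not loaded here, and args is passed through unused by this class:

import torch
import torch.nn as nn
import models.StyTR as StyTR
import models.transformer as transformer

vgg = nn.Sequential(*list(StyTR.vgg.children())[:44])   # encoder up to relu5_1 (random weights in this sketch)
network = StyTR.StyTrans(vgg, StyTR.decoder, StyTR.PatchEmbed(),
                         transformer.Transformer(), args=None)
content = torch.rand(1, 3, 256, 256)
style = torch.rand(1, 3, 256, 256)
Ics, loss_c, loss_s, l_id1, l_id2 = network(content, style)   # stylized image plus the four losses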
models/ViT_helper.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
5
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
6
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
7
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
8
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
9
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
10
+ 'survival rate' as the argument.
11
+ """
12
+ if drop_prob == 0. or not training:
13
+ return x
14
+ keep_prob = 1 - drop_prob
15
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
16
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
17
+ random_tensor.floor_() # binarize
18
+ output = x.div(keep_prob) * random_tensor
19
+ return output
20
+
21
+
22
+ class DropPath(nn.Module):
23
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
24
+ """
25
+ def __init__(self, drop_prob=None):
26
+ super(DropPath, self).__init__()
27
+ self.drop_prob = drop_prob
28
+
29
+ def forward(self, x):
30
+ return drop_path(x, self.drop_prob, self.training)
31
+
32
+ from itertools import repeat
33
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
34
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
35
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
36
+ from torch._six import container_abcs,int_classes
37
+ else:
38
+ import collections.abc as container_abcs
39
+ int_classes = int
40
+
41
+
42
+ # From PyTorch internals
43
+ def _ntuple(n):
44
+ def parse(x):
45
+ if isinstance(x, container_abcs.Iterable):
46
+ return x
47
+ return tuple(repeat(x, n))
48
+ return parse
49
+
50
+
51
+ to_1tuple = _ntuple(1)
52
+ to_2tuple = _ntuple(2)
53
+ to_3tuple = _ntuple(3)
54
+ to_4tuple = _ntuple(4)
55
+
56
+
57
+
58
+ import torch
59
+ import math
60
+ import warnings
61
+
62
+
63
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
64
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
65
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
66
+ def norm_cdf(x):
67
+ # Computes standard normal cumulative distribution function
68
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
69
+
70
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
71
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
72
+ "The distribution of values may be incorrect.",
73
+ stacklevel=2)
74
+
75
+ with torch.no_grad():
76
+ # Values are generated by using a truncated uniform distribution and
77
+ # then using the inverse CDF for the normal distribution.
78
+ # Get upper and lower cdf values
79
+ l = norm_cdf((a - mean) / std)
80
+ u = norm_cdf((b - mean) / std)
81
+
82
+ # Uniformly fill tensor with values from [l, u], then translate to
83
+ # [2l-1, 2u-1].
84
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
85
+
86
+ # Use inverse cdf transform for normal distribution to get truncated
87
+ # standard normal
88
+ tensor.erfinv_()
89
+
90
+ # Transform to proper mean, std
91
+ tensor.mul_(std * math.sqrt(2.))
92
+ tensor.add_(mean)
93
+
94
+ # Clamp to ensure it's in the proper range
95
+ tensor.clamp_(min=a, max=b)
96
+ return tensor
97
+
98
+
99
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
100
+ # type: (Tensor, float, float, float, float) -> Tensor
101
+ r"""Fills the input Tensor with values drawn from a truncated
102
+ normal distribution. The values are effectively drawn from the
103
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
104
+ with values outside :math:`[a, b]` redrawn until they are within
105
+ the bounds. The method used for generating the random values works
106
+ best when :math:`a \leq \text{mean} \leq b`.
107
+ Args:
108
+ tensor: an n-dimensional `torch.Tensor`
109
+ mean: the mean of the normal distribution
110
+ std: the standard deviation of the normal distribution
111
+ a: the minimum cutoff value
112
+ b: the maximum cutoff value
113
+ Examples:
114
+ >>> w = torch.empty(3, 5)
115
+ >>> nn.init.trunc_normal_(w)
116
+ """
117
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
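A minimal sketch of the two helpers defined above; the shapes are arbitrary:

import torch
from models.ViT_helper import trunc_normal_, DropPath

w = torch.empty(64, 512)
trunc_normal_(w, std=0.02)            # in-place truncated-normal initialization

drop = DropPath(drop_prob=0.1)        # stochastic depth on a residual branch
x = torch.randn(4, 196, 512)
out = x + drop(x)                      # drops whole samples with probability 0.1 while training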
models/__pycache__/StyTR.cpython-310.pyc ADDED
Binary file (7.41 kB).
 
models/__pycache__/ViT_helper.cpython-310.pyc ADDED
Binary file (4.21 kB).
 
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes).
 
models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (9.43 kB).
 
models/sampler.py ADDED
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+ from torch.utils import data
3
+
4
+
5
+ def InfiniteSampler(n):
6
+ # i = 0
7
+ i = n - 1
8
+ order = np.random.permutation(n)
9
+ while True:
10
+ yield order[i]
11
+ i += 1
12
+ if i >= n:
13
+ np.random.seed()
14
+ order = np.random.permutation(n)
15
+ i = 0
16
+
17
+
18
+ class InfiniteSamplerWrapper(data.sampler.Sampler):
19
+ def __init__(self, data_source):
20
+ self.num_samples = len(data_source)
21
+
22
+ def __iter__(self):
23
+ return iter(InfiniteSampler(self.num_samples))
24
+
25
+ def __len__(self):
26
+ return 2 ** 31
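A minimal sketch of the sampler in use, mirroring the DataLoader setup in train.py; the TensorDataset is a stand-in for a real image dataset:

import torch
import torch.utils.data as data
from models.sampler import InfiniteSamplerWrapper

dataset = data.TensorDataset(torch.randn(100, 3, 256, 256))   # stand-in dataset
loader = iter(data.DataLoader(dataset, batch_size=8,
                              sampler=InfiniteSamplerWrapper(dataset),
                              num_workers=0))
batch = next(loader)   # can be called indefinitely; the sampler reshuffles and never stops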
models/transformer.py ADDED
@@ -0,0 +1,322 @@
1
+ import copy
2
+ from typing import Optional, List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, Tensor
7
+ from function import normal,normal_style
8
+ import numpy as np
9
+ import os
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12
+ class Transformer(nn.Module):
13
+
14
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=3,
15
+ num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
16
+ activation="relu", normalize_before=False,
17
+ return_intermediate_dec=False):
18
+ super().__init__()
19
+
20
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
21
+ dropout, activation, normalize_before)
22
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
23
+ self.encoder_c = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
24
+ self.encoder_s = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
25
+
26
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
27
+ dropout, activation, normalize_before)
28
+ decoder_norm = nn.LayerNorm(d_model)
29
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
30
+ return_intermediate=return_intermediate_dec)
31
+
32
+ self._reset_parameters()
33
+
34
+ self.d_model = d_model
35
+ self.nhead = nhead
36
+
37
+ self.new_ps = nn.Conv2d(512 , 512 , (1,1))
38
+ self.averagepooling = nn.AdaptiveAvgPool2d(18)
39
+
40
+ def _reset_parameters(self):
41
+ for p in self.parameters():
42
+ if p.dim() > 1:
43
+ nn.init.xavier_uniform_(p)
44
+
45
+ def forward(self, style, mask , content, pos_embed_c, pos_embed_s):
46
+
47
+ # content-aware positional embedding
48
+ content_pool = self.averagepooling(content)
49
+ pos_c = self.new_ps(content_pool)
50
+ pos_embed_c = F.interpolate(pos_c, mode='bilinear',size= style.shape[-2:])
51
+
52
+ ###flatten NxCxHxW to HWxNxC
53
+ style = style.flatten(2).permute(2, 0, 1)
54
+ if pos_embed_s is not None:
55
+ pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)
56
+
57
+ content = content.flatten(2).permute(2, 0, 1)
58
+ if pos_embed_c is not None:
59
+ pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)
60
+
61
+
62
+ style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
63
+ content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
64
+ hs = self.decoder(content, style, memory_key_padding_mask=mask,
65
+ pos=pos_embed_s, query_pos=pos_embed_c)[0]
66
+
67
+ ### reshape HWxNxC back to NxCxHxW
68
+ N, B, C= hs.shape
69
+ H = int(np.sqrt(N))
70
+ hs = hs.permute(1, 2, 0)
71
+ hs = hs.view(B, C, -1,H)
72
+
73
+ return hs
74
+
75
+
76
+ class TransformerEncoder(nn.Module):
77
+
78
+ def __init__(self, encoder_layer, num_layers, norm=None):
79
+ super().__init__()
80
+ self.layers = _get_clones(encoder_layer, num_layers)
81
+ self.num_layers = num_layers
82
+ self.norm = norm
83
+
84
+ def forward(self, src,
85
+ mask: Optional[Tensor] = None,
86
+ src_key_padding_mask: Optional[Tensor] = None,
87
+ pos: Optional[Tensor] = None):
88
+ output = src
89
+
90
+ for layer in self.layers:
91
+ output = layer(output, src_mask=mask,
92
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
93
+
94
+ if self.norm is not None:
95
+ output = self.norm(output)
96
+
97
+ return output
98
+
99
+
100
+ class TransformerDecoder(nn.Module):
101
+
102
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
103
+ super().__init__()
104
+ self.layers = _get_clones(decoder_layer, num_layers)
105
+ self.num_layers = num_layers
106
+ self.norm = norm
107
+ self.return_intermediate = return_intermediate
108
+
109
+ def forward(self, tgt, memory,
110
+ tgt_mask: Optional[Tensor] = None,
111
+ memory_mask: Optional[Tensor] = None,
112
+ tgt_key_padding_mask: Optional[Tensor] = None,
113
+ memory_key_padding_mask: Optional[Tensor] = None,
114
+ pos: Optional[Tensor] = None,
115
+ query_pos: Optional[Tensor] = None):
116
+ output = tgt
117
+
118
+ intermediate = []
119
+
120
+ for layer in self.layers:
121
+ output = layer(output, memory, tgt_mask=tgt_mask,
122
+ memory_mask=memory_mask,
123
+ tgt_key_padding_mask=tgt_key_padding_mask,
124
+ memory_key_padding_mask=memory_key_padding_mask,
125
+ pos=pos, query_pos=query_pos)
126
+ if self.return_intermediate:
127
+ intermediate.append(self.norm(output))
128
+
129
+ if self.norm is not None:
130
+ output = self.norm(output)
131
+ if self.return_intermediate:
132
+ intermediate.pop()
133
+ intermediate.append(output)
134
+
135
+ if self.return_intermediate:
136
+ return torch.stack(intermediate)
137
+
138
+ return output.unsqueeze(0)
139
+
140
+
141
+ class TransformerEncoderLayer(nn.Module):
142
+
143
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
144
+ activation="relu", normalize_before=False):
145
+ super().__init__()
146
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
147
+ # Implementation of Feedforward model
148
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
149
+ self.dropout = nn.Dropout(dropout)
150
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
151
+
152
+ self.norm1 = nn.LayerNorm(d_model)
153
+ self.norm2 = nn.LayerNorm(d_model)
154
+ self.dropout1 = nn.Dropout(dropout)
155
+ self.dropout2 = nn.Dropout(dropout)
156
+
157
+ self.activation = _get_activation_fn(activation)
158
+ self.normalize_before = normalize_before
159
+
160
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
161
+ return tensor if pos is None else tensor + pos
162
+
163
+ def forward_post(self,
164
+ src,
165
+ src_mask: Optional[Tensor] = None,
166
+ src_key_padding_mask: Optional[Tensor] = None,
167
+ pos: Optional[Tensor] = None):
168
+ q = k = self.with_pos_embed(src, pos)
169
+ # q = k = src
170
+ # print(q.size(),k.size(),src.size())
171
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
172
+ key_padding_mask=src_key_padding_mask)[0]
173
+ src = src + self.dropout1(src2)
174
+ src = self.norm1(src)
175
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
176
+ src = src + self.dropout2(src2)
177
+ src = self.norm2(src)
178
+ return src
179
+
180
+ def forward_pre(self, src,
181
+ src_mask: Optional[Tensor] = None,
182
+ src_key_padding_mask: Optional[Tensor] = None,
183
+ pos: Optional[Tensor] = None):
184
+ src2 = self.norm1(src)
185
+ q = k = self.with_pos_embed(src2, pos)
186
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
187
+ key_padding_mask=src_key_padding_mask)[0]
188
+ src = src + self.dropout1(src2)
189
+ src2 = self.norm2(src)
190
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
191
+ src = src + self.dropout2(src2)
192
+ return src
193
+
194
+ def forward(self, src,
195
+ src_mask: Optional[Tensor] = None,
196
+ src_key_padding_mask: Optional[Tensor] = None,
197
+ pos: Optional[Tensor] = None):
198
+ if self.normalize_before:
199
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
200
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
201
+
202
+
203
+ class TransformerDecoderLayer(nn.Module):
204
+
205
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
206
+ activation="relu", normalize_before=False):
207
+ super().__init__()
208
+ # d_model embedding dim
209
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
210
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
211
+ # Implementation of Feedforward model
212
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
213
+ self.dropout = nn.Dropout(dropout)
214
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
215
+
216
+ self.norm1 = nn.LayerNorm(d_model)
217
+ self.norm2 = nn.LayerNorm(d_model)
218
+ self.norm3 = nn.LayerNorm(d_model)
219
+ self.dropout1 = nn.Dropout(dropout)
220
+ self.dropout2 = nn.Dropout(dropout)
221
+ self.dropout3 = nn.Dropout(dropout)
222
+
223
+ self.activation = _get_activation_fn(activation)
224
+ self.normalize_before = normalize_before
225
+
226
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
227
+ return tensor if pos is None else tensor + pos
228
+
229
+ def forward_post(self, tgt, memory,
230
+ tgt_mask: Optional[Tensor] = None,
231
+ memory_mask: Optional[Tensor] = None,
232
+ tgt_key_padding_mask: Optional[Tensor] = None,
233
+ memory_key_padding_mask: Optional[Tensor] = None,
234
+ pos: Optional[Tensor] = None,
235
+ query_pos: Optional[Tensor] = None):
236
+
237
+
238
+ q = self.with_pos_embed(tgt, query_pos)
239
+ k = self.with_pos_embed(memory, pos)
240
+ v = memory
241
+
242
+ tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
243
+ key_padding_mask=tgt_key_padding_mask)[0]
244
+
245
+ tgt = tgt + self.dropout1(tgt2)
246
+ tgt = self.norm1(tgt)
247
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
248
+ key=self.with_pos_embed(memory, pos),
249
+ value=memory, attn_mask=memory_mask,
250
+ key_padding_mask=memory_key_padding_mask)[0]
251
+ tgt = tgt + self.dropout2(tgt2)
252
+ tgt = self.norm2(tgt)
253
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
254
+ tgt = tgt + self.dropout3(tgt2)
255
+ tgt = self.norm3(tgt)
256
+ return tgt
257
+
258
+ def forward_pre(self, tgt, memory,
259
+ tgt_mask: Optional[Tensor] = None,
260
+ memory_mask: Optional[Tensor] = None,
261
+ tgt_key_padding_mask: Optional[Tensor] = None,
262
+ memory_key_padding_mask: Optional[Tensor] = None,
263
+ pos: Optional[Tensor] = None,
264
+ query_pos: Optional[Tensor] = None):
265
+ tgt2 = self.norm1(tgt)
266
+ q = k = self.with_pos_embed(tgt2, query_pos)
267
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
268
+ key_padding_mask=tgt_key_padding_mask)[0]
269
+
270
+ tgt = tgt + self.dropout1(tgt2)
271
+ tgt2 = self.norm2(tgt)
272
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
273
+ key=self.with_pos_embed(memory, pos),
274
+ value=memory, attn_mask=memory_mask,
275
+ key_padding_mask=memory_key_padding_mask)[0]
276
+
277
+ tgt = tgt + self.dropout2(tgt2)
278
+ tgt2 = self.norm3(tgt)
279
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
280
+ tgt = tgt + self.dropout3(tgt2)
281
+ return tgt
282
+
283
+ def forward(self, tgt, memory,
284
+ tgt_mask: Optional[Tensor] = None,
285
+ memory_mask: Optional[Tensor] = None,
286
+ tgt_key_padding_mask: Optional[Tensor] = None,
287
+ memory_key_padding_mask: Optional[Tensor] = None,
288
+ pos: Optional[Tensor] = None,
289
+ query_pos: Optional[Tensor] = None):
290
+ if self.normalize_before:
291
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
292
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
293
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
294
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
295
+
296
+
297
+ def _get_clones(module, N):
298
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
299
+
300
+
301
+ def build_transformer(args):
302
+ return Transformer(
303
+ d_model=args.hidden_dim,
304
+ dropout=args.dropout,
305
+ nhead=args.nheads,
306
+ dim_feedforward=args.dim_feedforward,
307
+ num_encoder_layers=args.enc_layers,
308
+ num_decoder_layers=args.dec_layers,
309
+ normalize_before=args.pre_norm,
310
+ return_intermediate_dec=True,
311
+ )
312
+
313
+
314
+ def _get_activation_fn(activation):
315
+ """Return an activation function given a string"""
316
+ if activation == "relu":
317
+ return F.relu
318
+ if activation == "gelu":
319
+ return F.gelu
320
+ if activation == "glu":
321
+ return F.glu
322
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
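A minimal shape-check sketch for the Transformer above; the 512-channel 32x32 feature maps correspond to what PatchEmbed would produce for a 256x256 input, and the mask and positional embeddings are left as None, as in StyTR.py:

import torch
from models.transformer import Transformer

trans = Transformer(d_model=512, nhead=8)
content = torch.randn(1, 512, 32, 32)          # N x C x H x W
style = torch.randn(1, 512, 32, 32)
hs = trans(style, None, content, None, None)   # content-aware positional embedding is computed internally
print(hs.shape)                                 # torch.Size([1, 512, 32, 32])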
outputs/test/0.jpg ADDED
sampler.py ADDED
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+ from torch.utils import data
3
+
4
+
5
+ def InfiniteSampler(n):
6
+ # i = 0
7
+ i = n - 1
8
+ order = np.random.permutation(n)
9
+ while True:
10
+ yield order[i]
11
+ i += 1
12
+ if i >= n:
13
+ np.random.seed()
14
+ order = np.random.permutation(n)
15
+ i = 0
16
+
17
+
18
+ class InfiniteSamplerWrapper(data.sampler.Sampler):
19
+ def __init__(self, data_source):
20
+ self.num_samples = len(data_source)
21
+
22
+ def __iter__(self):
23
+ return iter(InfiniteSampler(self.num_samples))
24
+
25
+ def __len__(self):
26
+ return 2 ** 31
style/style_image2.jpg ADDED
test.py ADDED
@@ -0,0 +1,183 @@
1
+ import argparse
2
+ from pathlib import Path
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ from os.path import basename
8
+ from os.path import splitext
9
+ from torchvision import transforms
10
+ from torchvision.utils import save_image
11
+ from function import calc_mean_std, normal, coral
12
+ import models.transformer as transformer
13
+ import models.StyTR as StyTR
14
+ import matplotlib.pyplot as plt
15
+ from matplotlib import cm
16
+ from function import normal
17
+ import numpy as np
18
+ import time
19
+ def test_transform(size, crop):
20
+ transform_list = []
21
+
22
+ if size != 0:
23
+ transform_list.append(transforms.Resize(size))
24
+ if crop:
25
+ transform_list.append(transforms.CenterCrop(size))
26
+ transform_list.append(transforms.ToTensor())
27
+ transform = transforms.Compose(transform_list)
28
+ return transform
29
+ def style_transform(h,w):
30
+ k = (h,w)
31
+ size = int(np.max(k))
32
+ print(type(size))
33
+ transform_list = []
34
+ transform_list.append(transforms.CenterCrop((h,w)))
35
+ transform_list.append(transforms.ToTensor())
36
+ transform = transforms.Compose(transform_list)
37
+ return transform
38
+
39
+ def content_transform():
40
+
41
+ transform_list = []
42
+ transform_list.append(transforms.ToTensor())
43
+ transform = transforms.Compose(transform_list)
44
+ return transform
45
+
46
+
47
+
48
+ parser = argparse.ArgumentParser()
49
+ # Basic options
50
+ parser.add_argument('--content', type=str,
51
+ help='File path to the content image')
52
+ parser.add_argument('--content_dir', type=str,
53
+ help='Directory path to a batch of content images')
54
+ parser.add_argument('--style', type=str,
55
+ help='File path to the style image, or multiple style \
56
+ images separated by commas if you want to do style \
57
+ interpolation or spatial control')
58
+ parser.add_argument('--style_dir', type=str,
59
+ help='Directory path to a batch of style images')
60
+ parser.add_argument('--output', type=str, default='output',
61
+ help='Directory to save the output image(s)')
62
+ parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
63
+ parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
64
+ parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
65
+ parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')
66
+
67
+
68
+ parser.add_argument('--style_interpolation_weights', type=str, default="")
69
+ parser.add_argument('--a', type=float, default=1.0)
70
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
71
+ help="Type of positional embedding to use on top of the image features")
72
+ parser.add_argument('--hidden_dim', default=512, type=int,
73
+ help="Size of the embeddings (dimension of the transformer)")
74
+ args = parser.parse_args()
75
+
76
+
77
+
78
+
79
+ # Advanced options
80
+ content_size=512
81
+ style_size=512
82
+ crop='store_true'
83
+ save_ext='.jpg'
84
+ output_path=args.output
85
+ preserve_color='store_true'
86
+ alpha=args.a
87
+
88
+
89
+
90
+
91
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+
93
+ # Either --content or --content_dir should be given.
94
+ if args.content:
95
+ content_paths = [Path(args.content)]
96
+ else:
97
+ content_dir = Path(args.content_dir)
98
+ content_paths = [f for f in content_dir.glob('*')]
99
+
100
+ # Either --style or --style_dir should be given.
101
+ if args.style:
102
+ style_paths = [Path(args.style)]
103
+ else:
104
+ style_dir = Path(args.style_dir)
105
+ style_paths = [f for f in style_dir.glob('*')]
106
+
107
+ if not os.path.exists(output_path):
108
+ os.mkdir(output_path)
109
+
110
+
111
+ vgg = StyTR.vgg
112
+ vgg.load_state_dict(torch.load(args.vgg))
113
+ vgg = nn.Sequential(*list(vgg.children())[:44])
114
+
115
+ decoder = StyTR.decoder
116
+ Trans = transformer.Transformer()
117
+ embedding = StyTR.PatchEmbed()
118
+
119
+ decoder.eval()
120
+ Trans.eval()
121
+ vgg.eval()
122
+ from collections import OrderedDict
123
+ new_state_dict = OrderedDict()
124
+ state_dict = torch.load(args.decoder_path)
125
+ for k, v in state_dict.items():
126
+ #namekey = k[7:] # remove `module.`
127
+ namekey = k
128
+ new_state_dict[namekey] = v
129
+ decoder.load_state_dict(new_state_dict)
130
+
131
+ new_state_dict = OrderedDict()
132
+ state_dict = torch.load(args.Trans_path)
133
+ for k, v in state_dict.items():
134
+ #namekey = k[7:] # remove `module.`
135
+ namekey = k
136
+ new_state_dict[namekey] = v
137
+ Trans.load_state_dict(new_state_dict)
138
+
139
+ new_state_dict = OrderedDict()
140
+ state_dict = torch.load(args.embedding_path)
141
+ for k, v in state_dict.items():
142
+ #namekey = k[7:] # remove `module.`
143
+ namekey = k
144
+ new_state_dict[namekey] = v
145
+ embedding.load_state_dict(new_state_dict)
146
+
147
+ network = StyTR.StyTrans(vgg,decoder,embedding,Trans,args)
148
+ network.eval()
149
+ network.to(device)
150
+
151
+
152
+
153
+ content_tf = test_transform(content_size, crop)
154
+ style_tf = test_transform(style_size, crop)
155
+
156
+ for content_path in content_paths:
157
+ for style_path in style_paths:
158
+ print(content_path)
159
+
160
+
161
+ content_tf1 = content_transform()
162
+ content = content_tf(Image.open(content_path).convert("RGB"))
163
+
164
+ h,w,c=np.shape(content)
165
+ style_tf1 = style_transform(h,w)
166
+ style = style_tf(Image.open(style_path).convert("RGB"))
167
+
168
+
169
+ style = style.to(device).unsqueeze(0)
170
+ content = content.to(device).unsqueeze(0)
171
+
172
+ with torch.no_grad():
173
+ output = network(content, style)[0]
174
+ output = output.cpu()
175
+
176
+ output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
177
+ output_path, splitext(basename(content_path))[0],
178
+ splitext(basename(style_path))[0], save_ext
179
+ )
180
+
181
+ save_image(output, output_name)
182
+
183
+
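An example invocation using the sample images included in this commit (the checkpoint paths fall back to the defaults under ./experiments):

python test.py --content images/images_image1.jpg --style style/style_image2.jpg --output outputs/test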
train.py ADDED
@@ -0,0 +1,210 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.data as data
6
+ from PIL import Image
7
+ from PIL import ImageFile
8
+ from tensorboardX import SummaryWriter
9
+ from torchvision import transforms
10
+ from tqdm import tqdm
11
+ from pathlib import Path
12
+ import models.transformer as transformer
13
+ import models.StyTR as StyTR
14
+ from sampler import InfiniteSamplerWrapper
15
+ from torchvision.utils import save_image
16
+
17
+
18
+ def train_transform():
19
+ transform_list = [
20
+ transforms.Resize(size=(512, 512)),
21
+ transforms.RandomCrop(256),
22
+ transforms.ToTensor()
23
+ ]
24
+ return transforms.Compose(transform_list)
25
+
26
+
27
+ class FlatFolderDataset(data.Dataset):
28
+ def __init__(self, root, transform):
29
+ super(FlatFolderDataset, self).__init__()
30
+ self.root = root
31
+ print(self.root)
32
+ self.path = os.listdir(self.root)
33
+ if os.path.isdir(os.path.join(self.root,self.path[0])):
34
+ self.paths = []
35
+ for file_name in os.listdir(self.root):
36
+ for file_name1 in os.listdir(os.path.join(self.root,file_name)):
37
+ self.paths.append(self.root+"/"+file_name+"/"+file_name1)
38
+ else:
39
+ self.paths = list(Path(self.root).glob('*'))
40
+ self.transform = transform
41
+ def __getitem__(self, index):
42
+ path = self.paths[index]
43
+ img = Image.open(str(path)).convert('RGB')
44
+ img = self.transform(img)
45
+ return img
46
+ def __len__(self):
47
+ return len(self.paths)
48
+ def name(self):
49
+ return 'FlatFolderDataset'
50
+
51
+ def adjust_learning_rate(optimizer, iteration_count):
52
+ """Imitating the original implementation"""
53
+ lr = 2e-4 / (1.0 + args.lr_decay * (iteration_count - 1e4))
54
+ for param_group in optimizer.param_groups:
55
+ param_group['lr'] = lr
56
+
57
+ def warmup_learning_rate(optimizer, iteration_count):
58
+ """Imitating the original implementation"""
59
+ lr = args.lr * 0.1 * (1.0 + 3e-4 * iteration_count)
60
+ # print(lr)
61
+ for param_group in optimizer.param_groups:
62
+ param_group['lr'] = lr
63
+
64
+
65
+ parser = argparse.ArgumentParser()
66
+ # Basic options
67
+ parser.add_argument('--content_dir', default=r'E:\NLP\VAL_Transformers\models\StyTr2\images', type=str,
68
+ help='Directory path to a batch of content images')
69
+ parser.add_argument('--style_dir', default=r'E:\NLP\VAL_Transformers\models\StyTr2\style', type=str, #wikiart dataset crawled from https://www.wikiart.org/
70
+ help='Directory path to a batch of style images')
71
+ parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')  # to run train.py, download the pretrained VGG checkpoint first
72
+
73
+ # training options
74
+ parser.add_argument('--save_dir', default='./experiments',
75
+ help='Directory to save the model')
76
+ parser.add_argument('--log_dir', default='./logs',
77
+ help='Directory to save the log')
78
+ parser.add_argument('--lr', type=float, default=5e-4)
79
+ parser.add_argument('--lr_decay', type=float, default=1e-5)
80
+ parser.add_argument('--max_iter', type=int, default=160000)
81
+ parser.add_argument('--batch_size', type=int, default=8)
82
+ parser.add_argument('--style_weight', type=float, default=10.0)
83
+ parser.add_argument('--content_weight', type=float, default=7.0)
84
+ parser.add_argument('--n_threads', type=int, default=1)
85
+ parser.add_argument('--save_model_interval', type=int, default=10000)
86
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
87
+ help="Type of positional embedding to use on top of the image features")
88
+ parser.add_argument('--hidden_dim', default=512, type=int,
89
+ help="Size of the embeddings (dimension of the transformer)")
90
+ args = parser.parse_args()
91
+
92
+ USE_CUDA = torch.cuda.is_available()
93
+ device = torch.device("cuda" if USE_CUDA else "cpu")
94
+ print(device)
95
+
96
+ if not os.path.exists(args.save_dir):
97
+ os.makedirs(args.save_dir)
98
+
99
+ if not os.path.exists(args.log_dir):
100
+ os.mkdir(args.log_dir)
101
+ writer = SummaryWriter(log_dir=args.log_dir)
102
+
103
+ vgg = StyTR.vgg
104
+ vgg.load_state_dict(torch.load(args.vgg))
105
+ vgg = nn.Sequential(*list(vgg.children())[:44])
106
+
107
+ decoder = StyTR.decoder
108
+ embedding = StyTR.PatchEmbed()
109
+
110
+ Trans = transformer.Transformer()
111
+ with torch.no_grad():
112
+ network = StyTR.StyTrans(vgg,decoder,embedding, Trans,args)
113
+ network.train()
114
+
115
+ network.to(device)
116
+ network = nn.DataParallel(network, device_ids=[0,1])
117
+ content_tf = train_transform()
118
+ style_tf = train_transform()
119
+
120
+
121
+
122
+ content_dataset = FlatFolderDataset(args.content_dir, content_tf)
123
+ style_dataset = FlatFolderDataset(args.style_dir, style_tf)
124
+
125
+ content_iter = iter(data.DataLoader(
126
+ content_dataset, batch_size=args.batch_size,
127
+ sampler=InfiniteSamplerWrapper(content_dataset),
128
+ num_workers=args.n_threads))
129
+ style_iter = iter(data.DataLoader(
130
+ style_dataset, batch_size=args.batch_size,
131
+ sampler=InfiniteSamplerWrapper(style_dataset),
132
+ num_workers=args.n_threads))
133
+
134
+
135
+ optimizer = torch.optim.Adam([
136
+ {'params': network.module.transformer.parameters()},
137
+ {'params': network.module.decode.parameters()},
138
+ {'params': network.module.embedding.parameters()},
139
+ ], lr=args.lr)
140
+
141
+
142
+ if not os.path.exists(args.save_dir+"/test"):
143
+ os.makedirs(args.save_dir+"/test")
144
+
145
+
146
+
147
+ for i in tqdm(range(args.max_iter)):
148
+
149
+ if i < 1e4:
150
+ warmup_learning_rate(optimizer, iteration_count=i)
151
+ else:
152
+ adjust_learning_rate(optimizer, iteration_count=i)
153
+
154
+ # print('learning_rate: %s' % str(optimizer.param_groups[0]['lr']))
155
+ content_images = next(content_iter).to(device)
156
+ style_images = next(style_iter).to(device)
157
+ out, loss_c, loss_s,l_identity1, l_identity2 = network(content_images, style_images)
158
+
159
+ if i % 100 == 0:
160
+ output_name = '{:s}/test/{:s}{:s}'.format(
161
+ args.save_dir, str(i),".jpg"
162
+ )
163
+ out = torch.cat((content_images,out),0)
164
+ out = torch.cat((style_images,out),0)
165
+ save_image(out, output_name)
166
+
167
+
168
+ loss_c = args.content_weight * loss_c
169
+ loss_s = args.style_weight * loss_s
170
+ loss = loss_c + loss_s + (l_identity1 * 70) + (l_identity2 * 1)
171
+
172
+ print(loss.sum().cpu().detach().numpy(),"-content:",loss_c.sum().cpu().detach().numpy(),"-style:",loss_s.sum().cpu().detach().numpy()
173
+ ,"-l1:",l_identity1.sum().cpu().detach().numpy(),"-l2:",l_identity2.sum().cpu().detach().numpy()
174
+ )
175
+
176
+ optimizer.zero_grad()
177
+ loss.sum().backward()
178
+ optimizer.step()
179
+
180
+ writer.add_scalar('loss_content', loss_c.sum().item(), i + 1)
181
+ writer.add_scalar('loss_style', loss_s.sum().item(), i + 1)
182
+ writer.add_scalar('loss_identity1', l_identity1.sum().item(), i + 1)
183
+ writer.add_scalar('loss_identity2', l_identity2.sum().item(), i + 1)
184
+ writer.add_scalar('total_loss', loss.sum().item(), i + 1)
185
+
186
+ if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
187
+ state_dict = network.module.transformer.state_dict()
188
+ for key in state_dict.keys():
189
+ state_dict[key] = state_dict[key].to(torch.device('cpu'))
190
+ torch.save(state_dict,
191
+ '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir,
192
+ i + 1))
193
+
194
+ state_dict = network.module.decode.state_dict()
195
+ for key in state_dict.keys():
196
+ state_dict[key] = state_dict[key].to(torch.device('cpu'))
197
+ torch.save(state_dict,
198
+ '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir,
199
+ i + 1))
200
+ state_dict = network.module.embedding.state_dict()
201
+ for key in state_dict.keys():
202
+ state_dict[key] = state_dict[key].to(torch.device('cpu'))
203
+ torch.save(state_dict,
204
+ '{:s}/embedding_iter_{:d}.pth'.format(args.save_dir,
205
+ i + 1))
206
+
207
+
208
+ writer.close()
209
+
210
+
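An example invocation; note that the script wraps the network in nn.DataParallel with device_ids=[0,1], so as written it assumes two CUDA devices:

python train.py --content_dir ./images --style_dir ./style --vgg ./experiments/vgg_normalised.pth --batch_size 8 --max_iter 160000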
util/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes).
 
util/__pycache__/box_ops.cpython-310.pyc ADDED
Binary file (2.72 kB).
 
util/__pycache__/misc.cpython-310.pyc ADDED
Binary file (14.6 kB).
 
util/box_ops.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
+ def box_cxcywh_to_xyxy(x):
10
+ x_c, y_c, w, h = x.unbind(-1)
11
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
13
+ return torch.stack(b, dim=-1)
14
+
15
+
16
+ def box_xyxy_to_cxcywh(x):
17
+ x0, y0, x1, y1 = x.unbind(-1)
18
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
19
+ (x1 - x0), (y1 - y0)]
20
+ return torch.stack(b, dim=-1)
21
+
22
+
23
+ # modified from torchvision to also return the union
24
+ def box_iou(boxes1, boxes2):
25
+ area1 = box_area(boxes1)
26
+ area2 = box_area(boxes2)
27
+
28
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30
+
31
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
32
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33
+
34
+ union = area1[:, None] + area2 - inter
35
+
36
+ iou = inter / union
37
+ return iou, union
38
+
39
+
40
+ def generalized_box_iou(boxes1, boxes2):
41
+ """
42
+ Generalized IoU from https://giou.stanford.edu/
43
+
44
+ The boxes should be in [x0, y0, x1, y1] format
45
+
46
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
47
+ and M = len(boxes2)
48
+ """
49
+ # degenerate boxes gives inf / nan results
50
+ # so do an early check
51
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
52
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
53
+ iou, union = box_iou(boxes1, boxes2)
54
+
55
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
56
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
57
+
58
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
59
+ area = wh[:, :, 0] * wh[:, :, 1]
60
+
61
+ return iou - (area - union) / area
62
+
63
+
64
+ def masks_to_boxes(masks):
65
+ """Compute the bounding boxes around the provided masks
66
+
67
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
68
+
69
+ Returns a [N, 4] tensors, with the boxes in xyxy format
70
+ """
71
+ if masks.numel() == 0:
72
+ return torch.zeros((0, 4), device=masks.device)
73
+
74
+ h, w = masks.shape[-2:]
75
+
76
+ y = torch.arange(0, h, dtype=torch.float)
77
+ x = torch.arange(0, w, dtype=torch.float)
78
+ y, x = torch.meshgrid(y, x)
79
+
80
+ x_mask = (masks * x.unsqueeze(0))
81
+ x_max = x_mask.flatten(1).max(-1)[0]
82
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83
+
84
+ y_mask = (masks * y.unsqueeze(0))
85
+ y_max = y_mask.flatten(1).max(-1)[0]
86
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
87
+
88
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
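A minimal sketch of the pairwise GIoU utility above; the boxes are arbitrary examples in the expected formats:

import torch
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

boxes1 = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4]]))    # cxcywh -> xyxy
boxes2 = torch.tensor([[0.2, 0.2, 0.8, 0.8], [0.0, 0.0, 0.5, 0.5]])  # already xyxy
giou = generalized_box_iou(boxes1, boxes2)                            # pairwise matrix of shape [1, 2]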
util/misc.py ADDED
@@ -0,0 +1,468 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import time
10
+ from collections import defaultdict, deque
11
+ import datetime
12
+ import pickle
13
+ from packaging import version
14
+ from typing import Optional, List
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from torch import Tensor
19
+
20
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
21
+ import torchvision
22
+ if version.parse(torchvision.__version__) < version.parse('0.7'):
23
+ from torchvision.ops import _new_empty_tensor
24
+ from torchvision.ops.misc import _output_size
25
+
26
+
27
+ class SmoothedValue(object):
28
+ """Track a series of values and provide access to smoothed values over a
29
+ window or the global series average.
30
+ """
31
+
32
+ def __init__(self, window_size=20, fmt=None):
33
+ if fmt is None:
34
+ fmt = "{median:.4f} ({global_avg:.4f})"
35
+ self.deque = deque(maxlen=window_size)
36
+ self.total = 0.0
37
+ self.count = 0
38
+ self.fmt = fmt
39
+
40
+ def update(self, value, n=1):
41
+ self.deque.append(value)
42
+ self.count += n
43
+ self.total += value * n
44
+
45
+ def synchronize_between_processes(self):
46
+ """
47
+ Warning: does not synchronize the deque!
48
+ """
49
+ if not is_dist_avail_and_initialized():
50
+ return
51
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
52
+ dist.barrier()
53
+ dist.all_reduce(t)
54
+ t = t.tolist()
55
+ self.count = int(t[0])
56
+ self.total = t[1]
57
+
58
+ @property
59
+ def median(self):
60
+ d = torch.tensor(list(self.deque))
61
+ return d.median().item()
62
+
63
+ @property
64
+ def avg(self):
65
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
66
+ return d.mean().item()
67
+
68
+ @property
69
+ def global_avg(self):
70
+ return self.total / self.count
71
+
72
+ @property
73
+ def max(self):
74
+ return max(self.deque)
75
+
76
+ @property
77
+ def value(self):
78
+ return self.deque[-1]
79
+
80
+ def __str__(self):
81
+ return self.fmt.format(
82
+ median=self.median,
83
+ avg=self.avg,
84
+ global_avg=self.global_avg,
85
+ max=self.max,
86
+ value=self.value)
87
+
88
+
89
+ def all_gather(data):
90
+ """
91
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
92
+ Args:
93
+ data: any picklable object
94
+ Returns:
95
+ list[data]: list of data gathered from each rank
96
+ """
97
+ world_size = get_world_size()
98
+ if world_size == 1:
99
+ return [data]
100
+
101
+ # serialized to a Tensor
102
+ buffer = pickle.dumps(data)
103
+ storage = torch.ByteStorage.from_buffer(buffer)
104
+ tensor = torch.ByteTensor(storage).to("cuda")
105
+
106
+ # obtain Tensor size of each rank
107
+ local_size = torch.tensor([tensor.numel()], device="cuda")
108
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
109
+ dist.all_gather(size_list, local_size)
110
+ size_list = [int(size.item()) for size in size_list]
111
+ max_size = max(size_list)
112
+
113
+ # receiving Tensor from all ranks
114
+ # we pad the tensor because torch all_gather does not support
115
+ # gathering tensors of different shapes
116
+ tensor_list = []
117
+ for _ in size_list:
118
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
119
+ if local_size != max_size:
120
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
121
+ tensor = torch.cat((tensor, padding), dim=0)
122
+ dist.all_gather(tensor_list, tensor)
123
+
124
+ data_list = []
125
+ for size, tensor in zip(size_list, tensor_list):
126
+ buffer = tensor.cpu().numpy().tobytes()[:size]
127
+ data_list.append(pickle.loads(buffer))
128
+
129
+ return data_list
130
+
131
+
132
+ def reduce_dict(input_dict, average=True):
133
+ """
134
+ Args:
135
+ input_dict (dict): all the values will be reduced
136
+ average (bool): whether to do average or sum
137
+ Reduce the values in the dictionary from all processes so that all processes
138
+ have the averaged results. Returns a dict with the same fields as
139
+ input_dict, after reduction.
140
+ """
141
+ world_size = get_world_size()
142
+ if world_size < 2:
143
+ return input_dict
144
+ with torch.no_grad():
145
+ names = []
146
+ values = []
147
+ # sort the keys so that they are consistent across processes
148
+ for k in sorted(input_dict.keys()):
149
+ names.append(k)
150
+ values.append(input_dict[k])
151
+ values = torch.stack(values, dim=0)
152
+ dist.all_reduce(values)
153
+ if average:
154
+ values /= world_size
155
+ reduced_dict = {k: v for k, v in zip(names, values)}
156
+ return reduced_dict
157
+
158
+
159
+ class MetricLogger(object):
160
+ def __init__(self, delimiter="\t"):
161
+ self.meters = defaultdict(SmoothedValue)
162
+ self.delimiter = delimiter
163
+
164
+ def update(self, **kwargs):
165
+ for k, v in kwargs.items():
166
+ if isinstance(v, torch.Tensor):
167
+ v = v.item()
168
+ assert isinstance(v, (float, int))
169
+ self.meters[k].update(v)
170
+
171
+ def __getattr__(self, attr):
172
+ if attr in self.meters:
173
+ return self.meters[attr]
174
+ if attr in self.__dict__:
175
+ return self.__dict__[attr]
176
+ raise AttributeError("'{}' object has no attribute '{}'".format(
177
+ type(self).__name__, attr))
178
+
179
+ def __str__(self):
180
+ loss_str = []
181
+ for name, meter in self.meters.items():
182
+ loss_str.append(
183
+ "{}: {}".format(name, str(meter))
184
+ )
185
+ return self.delimiter.join(loss_str)
186
+
187
+ def synchronize_between_processes(self):
188
+ for meter in self.meters.values():
189
+ meter.synchronize_between_processes()
190
+
191
+ def add_meter(self, name, meter):
192
+ self.meters[name] = meter
193
+
194
+ def log_every(self, iterable, print_freq, header=None):
195
+ i = 0
196
+ if not header:
197
+ header = ''
198
+ start_time = time.time()
199
+ end = time.time()
200
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
201
+ data_time = SmoothedValue(fmt='{avg:.4f}')
202
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
203
+ if torch.cuda.is_available():
204
+ log_msg = self.delimiter.join([
205
+ header,
206
+ '[{0' + space_fmt + '}/{1}]',
207
+ 'eta: {eta}',
208
+ '{meters}',
209
+ 'time: {time}',
210
+ 'data: {data}',
211
+ 'max mem: {memory:.0f}'
212
+ ])
213
+ else:
214
+ log_msg = self.delimiter.join([
215
+ header,
216
+ '[{0' + space_fmt + '}/{1}]',
217
+ 'eta: {eta}',
218
+ '{meters}',
219
+ 'time: {time}',
220
+ 'data: {data}'
221
+ ])
222
+ MB = 1024.0 * 1024.0
223
+ for obj in iterable:
224
+ data_time.update(time.time() - end)
225
+ yield obj
226
+ iter_time.update(time.time() - end)
227
+ if i % print_freq == 0 or i == len(iterable) - 1:
228
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
229
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
230
+ if torch.cuda.is_available():
231
+ print(log_msg.format(
232
+ i, len(iterable), eta=eta_string,
233
+ meters=str(self),
234
+ time=str(iter_time), data=str(data_time),
235
+ memory=torch.cuda.max_memory_allocated() / MB))
236
+ else:
237
+ print(log_msg.format(
238
+ i, len(iterable), eta=eta_string,
239
+ meters=str(self),
240
+ time=str(iter_time), data=str(data_time)))
241
+ i += 1
242
+ end = time.time()
243
+ total_time = time.time() - start_time
244
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
245
+ print('{} Total time: {} ({:.4f} s / it)'.format(
246
+ header, total_time_str, total_time / len(iterable)))
247
+
248
+
249
+ def get_sha():
250
+ cwd = os.path.dirname(os.path.abspath(__file__))
251
+
252
+ def _run(command):
253
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
254
+ sha = 'N/A'
255
+ diff = "clean"
256
+ branch = 'N/A'
257
+ try:
258
+ sha = _run(['git', 'rev-parse', 'HEAD'])
259
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
260
+ diff = _run(['git', 'diff-index', 'HEAD'])
261
+ diff = "has uncommitted changes" if diff else "clean"
262
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
263
+ except Exception:
264
+ pass
265
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
266
+ return message
267
+
268
+
269
+ def collate_fn(batch):
270
+ batch = list(zip(*batch))
271
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
272
+ return tuple(batch)
273
+
274
+
275
+ def _max_by_axis(the_list):
276
+ # type: (List[List[int]]) -> List[int]
277
+ maxes = the_list[0]
278
+ for sublist in the_list[1:]:
279
+ for index, item in enumerate(sublist):
280
+ maxes[index] = max(maxes[index], item)
281
+ return maxes
282
+
283
+
284
+ class NestedTensor(object):
285
+ def __init__(self, tensors, mask: Optional[Tensor]):
286
+ self.tensors = tensors
287
+ self.mask = mask
288
+
289
+ def to(self, device):
290
+ # type: (Device) -> NestedTensor # noqa
291
+ cast_tensor = self.tensors.to(device)
292
+ mask = self.mask
293
+ if mask is not None:
294
+ assert mask is not None
295
+ cast_mask = mask.to(device)
296
+ else:
297
+ cast_mask = None
298
+ return NestedTensor(cast_tensor, cast_mask)
299
+
300
+ def decompose(self):
301
+ return self.tensors, self.mask
302
+
303
+ def __repr__(self):
304
+ return str(self.tensors)
305
+
306
+
307
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
308
+ # TODO make this more general
309
+ if tensor_list[0].ndim == 3:
310
+ if torchvision._is_tracing():
311
+ # nested_tensor_from_tensor_list() does not export well to ONNX
312
+ # call _onnx_nested_tensor_from_tensor_list() instead
313
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
314
+
315
+ # TODO make it support different-sized images
316
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
317
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
318
+ batch_shape = [len(tensor_list)] + max_size
319
+ b, c, h, w = batch_shape
320
+ dtype = tensor_list[0].dtype
321
+ device = tensor_list[0].device
322
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
323
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
324
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
325
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
326
+ m[: img.shape[1], :img.shape[2]] = False
327
+ else:
328
+ raise ValueError('not supported')
329
+ return NestedTensor(tensor, mask)
330
+
331
+
332
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
333
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
334
+ @torch.jit.unused
335
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
336
+ max_size = []
337
+ for i in range(tensor_list[0].dim()):
338
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
339
+ max_size.append(max_size_i)
340
+ max_size = tuple(max_size)
341
+
342
+ # work around for
343
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
344
+ # m[: img.shape[1], :img.shape[2]] = False
345
+ # which is not yet supported in onnx
346
+ padded_imgs = []
347
+ padded_masks = []
348
+ for img in tensor_list:
349
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
350
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
351
+ padded_imgs.append(padded_img)
352
+
353
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
354
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
355
+ padded_masks.append(padded_mask.to(torch.bool))
356
+
357
+ tensor = torch.stack(padded_imgs)
358
+ mask = torch.stack(padded_masks)
359
+
360
+ return NestedTensor(tensor, mask=mask)
361
+
362
+
363
+ def setup_for_distributed(is_master):
364
+ """
365
+ This function disables printing when not in master process
366
+ """
367
+ import builtins as __builtin__
368
+ builtin_print = __builtin__.print
369
+
370
+ def print(*args, **kwargs):
371
+ force = kwargs.pop('force', False)
372
+ if is_master or force:
373
+ builtin_print(*args, **kwargs)
374
+
375
+ __builtin__.print = print
376
+
377
+
378
+ def is_dist_avail_and_initialized():
379
+ if not dist.is_available():
380
+ return False
381
+ if not dist.is_initialized():
382
+ return False
383
+ return True
384
+
385
+
386
+ def get_world_size():
387
+ if not is_dist_avail_and_initialized():
388
+ return 1
389
+ return dist.get_world_size()
390
+
391
+
392
+ def get_rank():
393
+ if not is_dist_avail_and_initialized():
394
+ return 0
395
+ return dist.get_rank()
396
+
397
+
398
+ def is_main_process():
399
+ return get_rank() == 0
400
+
401
+
402
+ def save_on_master(*args, **kwargs):
403
+ if is_main_process():
404
+ torch.save(*args, **kwargs)
405
+
406
+
407
+ def init_distributed_mode(args):
408
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
409
+ args.rank = int(os.environ["RANK"])
410
+ args.world_size = int(os.environ['WORLD_SIZE'])
411
+ args.gpu = int(os.environ['LOCAL_RANK'])
412
+ elif 'SLURM_PROCID' in os.environ:
413
+ args.rank = int(os.environ['SLURM_PROCID'])
414
+ args.gpu = args.rank % torch.cuda.device_count()
415
+ else:
416
+ print('Not using distributed mode')
417
+ args.distributed = False
418
+ return
419
+
420
+ args.distributed = True
421
+
422
+ torch.cuda.set_device(args.gpu)
423
+ args.dist_backend = 'nccl'
424
+ print('| distributed init (rank {}): {}'.format(
425
+ args.rank, args.dist_url), flush=True)
426
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
427
+ world_size=args.world_size, rank=args.rank)
428
+ torch.distributed.barrier()
429
+ setup_for_distributed(args.rank == 0)
430
+
431
+
432
+ @torch.no_grad()
433
+ def accuracy(output, target, topk=(1,)):
434
+ """Computes the precision@k for the specified values of k"""
435
+ if target.numel() == 0:
436
+ return [torch.zeros([], device=output.device)]
437
+ maxk = max(topk)
438
+ batch_size = target.size(0)
439
+
440
+ _, pred = output.topk(maxk, 1, True, True)
441
+ pred = pred.t()
442
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
443
+
444
+ res = []
445
+ for k in topk:
446
+ correct_k = correct[:k].view(-1).float().sum(0)
447
+ res.append(correct_k.mul_(100.0 / batch_size))
448
+ return res
449
+
450
+
451
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
452
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
453
+ """
454
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
455
+ This will eventually be supported natively by PyTorch, and this
456
+ helper can go away.
457
+ """
458
+ if version.parse(torchvision.__version__) < version.parse('0.7'):
459
+ if input.numel() > 0:
460
+ return torch.nn.functional.interpolate(
461
+ input, size, scale_factor, mode, align_corners
462
+ )
463
+
464
+ output_shape = _output_size(2, input, size, scale_factor)
465
+ output_shape = list(input.shape[:-2]) + list(output_shape)
466
+ return _new_empty_tensor(input, output_shape)
467
+ else:
468
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
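A rough usage sketch of the padding and logging helpers in this file; the image sizes and metric names are illustrative, and util is assumed to be importable from the repository root:

import torch
from util import misc

# Pad two differently sized 3xHxW images into one batch plus a padding mask.
imgs = [torch.rand(3, 256, 192), torch.rand(3, 224, 224)]
samples = misc.nested_tensor_from_tensor_list(imgs)
tensors, mask = samples.decompose()  # tensors: [2, 3, 256, 224], mask: [2, 256, 224]

# Track smoothed metrics during training and print them in one line.
logger = misc.MetricLogger(delimiter="  ")
logger.update(loss=0.5, loss_content=0.3)
print(logger)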
util/plot_utils.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Plotting utilities to visualize training logs.
3
+ """
4
+ import torch
5
+ import pandas as pd
6
+ import numpy as np
7
+ import seaborn as sns
8
+ import matplotlib.pyplot as plt
9
+
10
+ from pathlib import Path, PurePath
11
+
12
+
13
+ def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
14
+ '''
15
+ Function to plot specific fields from training log(s). Plots both training and test results.
16
+
17
+ :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
18
+ - fields = which results to plot from each log file - plots both training and test for each field.
19
+ - ewm_col = optional, center-of-mass (com) value used for exponential weighted smoothing of the plotted curves
20
+ - log_name = optional, name of log file if different than default 'log.txt'.
21
+
22
+ :: Outputs - matplotlib plots of results in fields, color coded for each log file.
23
+ - solid lines are training results, dashed lines are test results.
24
+
25
+ '''
26
+ func_name = "plot_utils.py::plot_logs"
27
+
28
+ # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
29
+ # convert single Path to list to avoid 'not iterable' error
30
+
31
+ if not isinstance(logs, list):
32
+ if isinstance(logs, PurePath):
33
+ logs = [logs]
34
+ print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
35
+ else:
36
+ raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
37
+ Expect list[Path] or single Path obj, received {type(logs)}")
38
+
39
+ # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
40
+ for i, dir in enumerate(logs):
41
+ if not isinstance(dir, PurePath):
42
+ raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
43
+ if not dir.exists():
44
+ raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
45
+ # verify log_name exists
46
+ fn = Path(dir / log_name)
47
+ if not fn.exists():
48
+ print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
49
+ print(f"--> full path of missing log file: {fn}")
50
+ return
51
+
52
+ # load log file(s) and plot
53
+ dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
54
+
55
+ fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
56
+
57
+ for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
58
+ for j, field in enumerate(fields):
59
+ if field == 'mAP':
60
+ coco_eval = pd.DataFrame(
61
+ np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
62
+ ).ewm(com=ewm_col).mean()
63
+ axs[j].plot(coco_eval, c=color)
64
+ else:
65
+ df.interpolate().ewm(com=ewm_col).mean().plot(
66
+ y=[f'train_{field}', f'test_{field}'],
67
+ ax=axs[j],
68
+ color=[color] * 2,
69
+ style=['-', '--']
70
+ )
71
+ for ax, field in zip(axs, fields):
72
+ ax.legend([Path(p).name for p in logs])
73
+ ax.set_title(field)
74
+
75
+
76
+ def plot_precision_recall(files, naming_scheme='iter'):
77
+ if naming_scheme == 'exp_id':
78
+ # name becomes exp_id
79
+ names = [f.parts[-3] for f in files]
80
+ elif naming_scheme == 'iter':
81
+ names = [f.stem for f in files]
82
+ else:
83
+ raise ValueError(f'not supported {naming_scheme}')
84
+ fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
85
+ for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
86
+ data = torch.load(f)
87
+ # precision is n_iou, n_points, n_cat, n_area, max_det
88
+ precision = data['precision']
89
+ recall = data['params'].recThrs
90
+ scores = data['scores']
91
+ # take precision for all classes, all areas and 100 detections
92
+ precision = precision[0, :, :, 0, -1].mean(1)
93
+ scores = scores[0, :, :, 0, -1].mean(1)
94
+ prec = precision.mean()
95
+ rec = data['recall'][0, :, 0, -1].mean()
96
+ print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
97
+ f'score={scores.mean():0.3f}, ' +
98
+ f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
99
+ )
100
+ axs[0].plot(recall, precision, c=color)
101
+ axs[1].plot(recall, scores, c=color)
102
+
103
+ axs[0].set_title('Precision / Recall')
104
+ axs[0].legend(names)
105
+ axs[1].set_title('Scores / Recall')
106
+ axs[1].legend(names)
107
+ return fig, axs
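plot_logs can be exercised along these lines; the experiment directories below are hypothetical, and each is assumed to contain a log.txt with one JSON record per line holding the train_*/test_* fields being plotted:

from pathlib import Path
from util.plot_utils import plot_logs

# Hypothetical run directories; point these at wherever training writes log.txt.
log_dirs = [Path('experiments/run1'), Path('experiments/run2')]
plot_logs(log_dirs, fields=('loss', 'loss_bbox_unscaled'), ewm_col=0)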