Dat Nguyen-Tien committed a38a6fb ("upload_huggff")
Parent(s): 476c59d

Files changed:
- __init__.py +1 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/function.cpython-310.pyc +0 -0
- __pycache__/sampler.cpython-310.pyc +0 -0
- experiments/decoder_iter_160000.pth +3 -0
- experiments/embedding_iter_160000.pth +3 -0
- experiments/transformer_iter_160000.pth +3 -0
- experiments/vgg_normalised.pth +3 -0
- function.py +73 -0
- images/images_image1.jpg +0 -0
- logs/events.out.tfevents.1717260856.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717261358.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717261528.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717262282.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717262831.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717262870.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300182.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300226.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300229.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300284.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300608.NGUYENTIENDAT +3 -0
- logs/events.out.tfevents.1717300611.NGUYENTIENDAT +3 -0
- models/StyTR.py +251 -0
- models/ViT_helper.py +117 -0
- models/__pycache__/StyTR.cpython-310.pyc +0 -0
- models/__pycache__/ViT_helper.cpython-310.pyc +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/transformer.cpython-310.pyc +0 -0
- models/sampler.py +26 -0
- models/transformer.py +322 -0
- outputs/test/0.jpg +0 -0
- sampler.py +26 -0
- style/style_image2.jpg +0 -0
- test.py +183 -0
- train.py +210 -0
- util/__init__.py +1 -0
- util/__pycache__/__init__.cpython-310.pyc +0 -0
- util/__pycache__/box_ops.cpython-310.pyc +0 -0
- util/__pycache__/misc.cpython-310.pyc +0 -0
- util/box_ops.py +88 -0
- util/misc.py +468 -0
- util/plot_utils.py +107 -0
__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (142 Bytes).
__pycache__/function.cpython-310.pyc
ADDED
Binary file (2.17 kB).
__pycache__/sampler.cpython-310.pyc
ADDED
Binary file (1.12 kB).
experiments/decoder_iter_160000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57133e62081041cfc8c16d921071832837844f69a01b1705c7511faa2d2b9eee
size 14027089
experiments/embedding_iter_160000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f56cb10d331e980654469c80715da5d9eac5e8455f8ed1df41b7141e0612d53a
size 396481
experiments/transformer_iter_160000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45316082fd7359864cb08bac63b7c8f278f881763ae6fd97416610c2cdca4e67
size 127208897
experiments/vgg_normalised.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:804ca2835ecf7539f0cd2a7ac3c18ce81e6f8468969ae7117ac0c148d286bb4a
size 80102481
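The four .pth entries above are Git LFS pointer files rather than the weights themselves; each records only the SHA-256 and byte size of the real checkpoint. Below is a minimal sketch (not part of this commit) for verifying a pulled checkpoint against its pointer, assuming the file has been fetched with git-lfs and sits at experiments/vgg_normalised.pth; the expected hash is copied from the pointer above.

# Hypothetical helper, not part of the commit: compare a pulled LFS file
# against the oid recorded in its pointer file.
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "804ca2835ecf7539f0cd2a7ac3c18ce81e6f8468969ae7117ac0c148d286bb4a"
assert sha256_of("experiments/vgg_normalised.pth") == expected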
function.py
ADDED
@@ -0,0 +1,73 @@
import torch


def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def calc_mean_std1(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    # assert (len(size) == 4)
    WH, N, C = size
    feat_var = feat.var(dim=0) + eps
    feat_std = feat_var.sqrt()
    feat_mean = feat.mean(dim=0)
    return feat_mean, feat_std

def normal(feat, eps=1e-5):
    feat_mean, feat_std = calc_mean_std(feat, eps)
    normalized = (feat - feat_mean) / feat_std
    return normalized

def normal_style(feat, eps=1e-5):
    feat_mean, feat_std = calc_mean_std1(feat, eps)
    normalized = (feat - feat_mean) / feat_std
    return normalized


def _calc_feat_flatten_mean_std(feat):
    # takes 3D feat (C, H, W), returns mean and std of array within channels
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std


def _mat_sqrt(x):
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


def coral(source, target):
    # assume both source and target are 3D arrays (C, H, W)
    # Note: flatten -> f

    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
    source_f_norm = (source_f - source_f_mean.expand_as(
        source_f)) / source_f_std.expand_as(source_f)
    source_f_cov_eye = \
        torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
    target_f_norm = (target_f - target_f_mean.expand_as(
        target_f)) / target_f_std.expand_as(target_f)
    target_f_cov_eye = \
        torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

    source_f_norm_transfer = torch.mm(
        _mat_sqrt(target_f_cov_eye),
        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
                 source_f_norm)
    )

    source_f_transfer = source_f_norm_transfer * \
                        target_f_std.expand_as(source_f_norm) + \
                        target_f_mean.expand_as(source_f_norm)

    return source_f_transfer.view(source.size())
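A quick sketch of how the channel-wise statistics above behave on a 4-D feature map; the tensor sizes are illustrative and the snippet assumes it is run from the repository root so that function.py is importable.

# Illustrative only: calc_mean_std reduces over the spatial dims per (N, C),
# and normal() standardizes each channel of each sample to zero mean / unit std.
import torch
from function import calc_mean_std, normal

feat = torch.rand(2, 512, 32, 32)            # (N, C, H, W)
mean, std = calc_mean_std(feat)
print(mean.shape, std.shape)                 # torch.Size([2, 512, 1, 1]) twice
normed = normal(feat)
print(normed.mean().item(), normed.std().item())  # roughly 0 and 1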
images/images_image1.jpg
ADDED
Image file.
logs/events.out.tfevents.1717260856.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:561832434af350a6f0e0460f79365352c81825418a2f61f43edf7c5224c63c18
size 40
logs/events.out.tfevents.1717261358.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f8684900a20b584701c30b530b9275b9918bf2f040446e95aee1bf7ff0a505b
size 40
logs/events.out.tfevents.1717261528.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab1d65bcd8f36cf6678e8c25ea6ff3fa07db539ca9a0b59c2437bf020b9e5502
size 40
logs/events.out.tfevents.1717262282.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a0cc507c6909cfb2822bc3be0832c958efbd42f6f07317f1a613118593123ed
size 40
logs/events.out.tfevents.1717262831.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:727528a0119df4f847c66c0acc34d64f96a57a9bb8dda57bafc4a68070157685
size 40
logs/events.out.tfevents.1717262870.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2042827551df19eb1bd516c6e0a31acea34ae0c427d7f233064897db9d7de7c2
size 40
logs/events.out.tfevents.1717300182.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac2eda70720342dc7dde77eff301a8c653df2a2ad9f85d8311bb9112d9cbff0b
size 40
logs/events.out.tfevents.1717300226.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa75376b2ac29e81ef7a15af8d07ef6921dfbc07107f24d98daf195d675249d2
size 40
logs/events.out.tfevents.1717300229.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4226559cd8c53df827cbeb02f96f80d60058d1f7cb89f8750462901b2676924
size 40
logs/events.out.tfevents.1717300284.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f17141f226ef2ba638e4714de753ecae20ce00b091f5f1c784ccb1247464264f
size 790
logs/events.out.tfevents.1717300608.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba8bbc1a99b7c86c3fad38911ee90e762ab94cfd6d3601dabc2d3472d002fe81
size 40
logs/events.out.tfevents.1717300611.NGUYENTIENDAT
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5797aba72e948834e99c14dd5c467030a7ba4cd979b8abf381132391da8e946c
size 40
models/StyTR.py
ADDED
@@ -0,0 +1,251 @@
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                       accuracy, get_world_size, interpolate,
                       is_dist_avail_and_initialized)
from function import normal, normal_style
from function import calc_mean_std
import scipy.stats as stats
from models.ViT_helper import DropPath, to_2tuple, trunc_normal_

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        B, C, H, W = x.shape
        print(f"PatchEmbed Input: {x.shape}")
        x = self.proj(x)
        print(f"PatchEmbed Output: {x.shape}")

        return x


decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)
################# Print the input/output shape of each decoder layer
for name, module in decoder.named_children():
    def hook(module, input, output):
        print(f"{module.__class__.__name__} Input: {input[0].shape}")
        print(f"{module.__class__.__name__} Output: {output.shape}")
    module.register_forward_hook(hook)


vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

################# Print the input/output shape of each VGG encoder layer
for name, module in vgg.named_children():
    def hook(module, input, output):
        print(f"{module.__class__.__name__} Input: {input[0].shape}")
        print(f"{module.__class__.__name__} Output: {output.shape}")
    module.register_forward_hook(hook)

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            print(f"MLP Layer {i} Input: {x.shape}")
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
            print(f"MLP Layer {i} Output: {x.shape}")
        return x

class StyTrans(nn.Module):
    """ This is the style transform transformer module """

    def __init__(self, encoder, decoder, PatchEmbed, transformer, args):

        super().__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1

        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

        self.mse_loss = nn.MSELoss()
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.decode = decoder
        self.embedding = PatchEmbed

    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(5):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        """
        content_input = samples_c
        style_input = samples_s
        if isinstance(samples_c, (list, torch.Tensor)):
            samples_c = nested_tensor_from_tensor_list(samples_c)  # support different-sized images; padding is used for the mask [tensor, mask]
        if isinstance(samples_s, (list, torch.Tensor)):
            samples_s = nested_tensor_from_tensor_list(samples_s)

        # ### features used to calculate loss
        content_feats = self.encode_with_intermediate(samples_c.tensors)
        style_feats = self.encode_with_intermediate(samples_s.tensors)

        ### Linear projection
        print(f"Embedding Content Input: {samples_c.tensors.shape}")
        style = self.embedding(samples_s.tensors)
        print(f"Style Output: {style.shape}")
        content = self.embedding(samples_c.tensors)
        print(f"Embedding Content Output: {content.shape}")

        # positional embedding is calculated in transformer.py
        pos_s = None
        pos_c = None

        mask = None
        hs = self.transformer(style, mask, content, pos_c, pos_s)
        Ics = self.decode(hs)

        Ics_feats = self.encode_with_intermediate(Ics)
        loss_c = self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1])) + self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
        # Style loss
        loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])


        Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
        Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))

        # Identity losses lambda 1
        loss_lambda1 = self.calc_content_loss(Icc, content_input) + self.calc_content_loss(Iss, style_input)

        # Identity losses lambda 2
        Icc_feats = self.encode_with_intermediate(Icc)
        Iss_feats = self.encode_with_intermediate(Iss)
        loss_lambda2 = self.calc_content_loss(Icc_feats[0], content_feats[0]) + self.calc_content_loss(Iss_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_lambda2 += self.calc_content_loss(Icc_feats[i], content_feats[i]) + self.calc_content_loss(Iss_feats[i], style_feats[i])
        # Keep one of the following two return statements (train vs. test)
        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2  # train
        # return Ics  # test
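A small shape check for PatchEmbed under its default settings above (img_size=256, patch_size=8, embed_dim=512): the stride-8 projection turns a 3x256x256 image into a 512x32x32 feature map. This is a sketch only, assuming the util package and scipy are importable so that models.StyTR imports cleanly.

# Illustrative only: PatchEmbed maps (B, 3, 256, 256) -> (B, 512, 32, 32).
import torch
from models.StyTR import PatchEmbed

embed = PatchEmbed()                 # defaults: img_size=256, patch_size=8
x = torch.rand(1, 3, 256, 256)
out = embed(x)                       # the forward also prints both shapes
print(out.shape)                     # torch.Size([1, 512, 32, 32])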
models/ViT_helper.py
ADDED
@@ -0,0 +1,117 @@
import torch
from torch import nn

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

from itertools import repeat
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs, int_classes
else:
    import collections.abc as container_abcs
    int_classes = int


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


import torch
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
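For illustration, trunc_normal_ is the kind of initializer ViT-style models use for projection weights; a minimal sketch, assuming models.ViT_helper is importable:

# Illustrative only: fill a weight tensor from a truncated normal
# (values clamped to the default cutoffs [a, b] = [-2, 2]).
import torch
from models.ViT_helper import trunc_normal_

w = torch.empty(512, 512)
trunc_normal_(w, std=.02)
print(w.mean().item(), w.std().item())   # roughly 0 and 0.02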
models/__pycache__/StyTR.cpython-310.pyc
ADDED
Binary file (7.41 kB).
models/__pycache__/ViT_helper.cpython-310.pyc
ADDED
Binary file (4.21 kB).
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes).
models/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (9.43 kB).
models/sampler.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np
from torch.utils import data


def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31
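A minimal sketch of how this sampler is meant to be used: wrapped around a dataset so the DataLoader can be drawn from indefinitely (train.py below does exactly this with FlatFolderDataset). The toy dataset here is illustrative only.

# Illustrative only: an endless DataLoader built on InfiniteSamplerWrapper.
import torch
from torch.utils import data
from models.sampler import InfiniteSamplerWrapper

dataset = data.TensorDataset(torch.arange(10).float())
loader = iter(data.DataLoader(dataset, batch_size=4,
                              sampler=InfiniteSamplerWrapper(dataset)))
for _ in range(5):            # more draws than one pass over the data
    batch, = next(loader)
    print(batch.tolist())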
models/transformer.py
ADDED
@@ -0,0 +1,322 @@
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from function import normal, normal_style
import numpy as np
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder_c = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        self.encoder_s = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.new_ps = nn.Conv2d(512, 512, (1, 1))
        self.averagepooling = nn.AdaptiveAvgPool2d(18)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, style, mask, content, pos_embed_c, pos_embed_s):

        # content-aware positional embedding
        content_pool = self.averagepooling(content)
        pos_c = self.new_ps(content_pool)
        pos_embed_c = F.interpolate(pos_c, mode='bilinear', size=style.shape[-2:])

        ### flatten NxCxHxW to HWxNxC
        style = style.flatten(2).permute(2, 0, 1)
        if pos_embed_s is not None:
            pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)

        content = content.flatten(2).permute(2, 0, 1)
        if pos_embed_c is not None:
            pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)


        style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
        content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
        hs = self.decoder(content, style, memory_key_padding_mask=mask,
                          pos=pos_embed_s, query_pos=pos_embed_c)[0]

        ### HWxNxC back to NxCxHxW
        N, B, C = hs.shape
        H = int(np.sqrt(N))
        hs = hs.permute(1, 2, 0)
        hs = hs.view(B, C, -1, H)

        return hs


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        # q = k = src
        # print(q.size(), k.size(), src.size())
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # d_model embedding dim
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        q = self.with_pos_embed(tgt, query_pos)
        k = self.with_pos_embed(memory, pos)
        v = memory

        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
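A shape-level sketch of the Transformer above: both inputs are 4-D feature maps (e.g. PatchEmbed outputs), the positional arguments may be None because the content-aware positional embedding is computed inside forward, and the result is reshaped back to NxCxHxW. The 16x16 spatial size is illustrative, and the snippet assumes the repo root is on the Python path.

# Illustrative only: run the style/content transformer on dummy feature maps.
import torch
from models.transformer import Transformer

trans = Transformer()                   # d_model=512, 3 encoder + 3 decoder layers
style = torch.rand(1, 512, 16, 16)
content = torch.rand(1, 512, 16, 16)
with torch.no_grad():
    hs = trans(style, None, content, None, None)
print(hs.shape)                         # torch.Size([1, 512, 16, 16])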
outputs/test/0.jpg
ADDED
Image file.
sampler.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np
from torch.utils import data


def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31
style/style_image2.jpg
ADDED
Image file.
test.py
ADDED
@@ -0,0 +1,183 @@
import argparse
from pathlib import Path
import os
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image
from function import calc_mean_std, normal, coral
import models.transformer as transformer
import models.StyTR as StyTR
import matplotlib.pyplot as plt
from matplotlib import cm
from function import normal
import numpy as np
import time

def test_transform(size, crop):
    transform_list = []

    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def style_transform(h, w):
    k = (h, w)
    size = int(np.max(k))
    print(type(size))
    transform_list = []
    transform_list.append(transforms.CenterCrop((h, w)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def content_transform():

    transform_list = []
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str,
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
                    help='File path to the style image, or multiple style \
                          images separated by commas if you want to do style \
                          interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
                    help='Directory path to a batch of style images')
parser.add_argument('--output', type=str, default='output',
                    help='Directory to save the output image(s)')
parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')


parser.add_argument('--style_interpolation_weights', type=str, default="")
parser.add_argument('--a', type=float, default=1.0)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                    help="Type of positional embedding to use on top of the image features")
parser.add_argument('--hidden_dim', default=512, type=int,
                    help="Size of the embeddings (dimension of the transformer)")
args = parser.parse_args()


# Advanced options
content_size = 512
style_size = 512
crop = 'store_true'
save_ext = '.jpg'
output_path = args.output
preserve_color = 'store_true'
alpha = args.a


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Either --content or --content_dir should be given.
if args.content:
    content_paths = [Path(args.content)]
else:
    content_dir = Path(args.content_dir)
    content_paths = [f for f in content_dir.glob('*')]

# Either --style or --style_dir should be given.
if args.style:
    style_paths = [Path(args.style)]
else:
    style_dir = Path(args.style_dir)
    style_paths = [f for f in style_dir.glob('*')]

if not os.path.exists(output_path):
    os.mkdir(output_path)


vgg = StyTR.vgg
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:44])

decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

decoder.eval()
Trans.eval()
vgg.eval()
from collections import OrderedDict
new_state_dict = OrderedDict()
state_dict = torch.load(args.decoder_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove `module.`
    namekey = k
    new_state_dict[namekey] = v
decoder.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.Trans_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove `module.`
    namekey = k
    new_state_dict[namekey] = v
Trans.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.embedding_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove `module.`
    namekey = k
    new_state_dict[namekey] = v
embedding.load_state_dict(new_state_dict)

network = StyTR.StyTrans(vgg, decoder, embedding, Trans, args)
network.eval()
network.to(device)


content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

for content_path in content_paths:
    for style_path in style_paths:
        print(content_path)

        content_tf1 = content_transform()
        content = content_tf(Image.open(content_path).convert("RGB"))

        h, w, c = np.shape(content)
        style_tf1 = style_transform(h, w)
        style = style_tf(Image.open(style_path).convert("RGB"))

        style = style.to(device).unsqueeze(0)
        content = content.to(device).unsqueeze(0)

        with torch.no_grad():
            output = network(content, style)[0]
        output = output.cpu()

        output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
            output_path, splitext(basename(content_path))[0],
            splitext(basename(style_path))[0], save_ext
        )

        save_image(output, output_name)
train.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.data as data
|
6 |
+
from PIL import Image
|
7 |
+
from PIL import ImageFile
|
8 |
+
from tensorboardX import SummaryWriter
|
9 |
+
from torchvision import transforms
|
10 |
+
from tqdm import tqdm
|
11 |
+
from pathlib import Path
|
12 |
+
import models.transformer as transformer
|
13 |
+
import models.StyTR as StyTR
|
14 |
+
from sampler import InfiniteSamplerWrapper
|
15 |
+
from torchvision.utils import save_image
|
16 |
+
|
17 |
+
|
18 |
+
def train_transform():
|
19 |
+
transform_list = [
|
20 |
+
transforms.Resize(size=(512, 512)),
|
21 |
+
transforms.RandomCrop(256),
|
22 |
+
transforms.ToTensor()
|
23 |
+
]
|
24 |
+
return transforms.Compose(transform_list)
|
25 |
+
|
26 |
+
|
27 |
+
class FlatFolderDataset(data.Dataset):
|
28 |
+
def __init__(self, root, transform):
|
29 |
+
super(FlatFolderDataset, self).__init__()
|
30 |
+
self.root = root
|
31 |
+
print(self.root)
|
32 |
+
self.path = os.listdir(self.root)
|
33 |
+
if os.path.isdir(os.path.join(self.root,self.path[0])):
|
34 |
+
self.paths = []
|
35 |
+
for file_name in os.listdir(self.root):
|
36 |
+
for file_name1 in os.listdir(os.path.join(self.root,file_name)):
|
37 |
+
self.paths.append(self.root+"/"+file_name+"/"+file_name1)
|
38 |
+
else:
|
39 |
+
self.paths = list(Path(self.root).glob('*'))
|
40 |
+
self.transform = transform
|
41 |
+
def __getitem__(self, index):
|
42 |
+
path = self.paths[index]
|
43 |
+
img = Image.open(str(path)).convert('RGB')
|
44 |
+
img = self.transform(img)
|
45 |
+
return img
|
46 |
+
def __len__(self):
|
47 |
+
return len(self.paths)
|
48 |
+
def name(self):
|
49 |
+
return 'FlatFolderDataset'
|
50 |
+
|
51 |
+
def adjust_learning_rate(optimizer, iteration_count):
|
52 |
+
"""Imitating the original implementation"""
|
53 |
+
+    lr = 2e-4 / (1.0 + args.lr_decay * (iteration_count - 1e4))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def warmup_learning_rate(optimizer, iteration_count):
+    """Imitating the original implementation"""
+    lr = args.lr * 0.1 * (1.0 + 3e-4 * iteration_count)
+    # print(lr)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+parser = argparse.ArgumentParser()
+# Basic options
+parser.add_argument('--content_dir', default=r'E:\NLP\VAL_Transformers\models\StyTr2\images', type=str,
+                    help='Directory path to a batch of content images')
+parser.add_argument('--style_dir', default=r'E:\NLP\VAL_Transformers\models\StyTr2\style', type=str,  # wikiart dataset crawled from https://www.wikiart.org/
+                    help='Directory path to a batch of style images')
+parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')  # to run train.py, please download the pretrained VGG checkpoint first
+
+# training options
+parser.add_argument('--save_dir', default='./experiments',
+                    help='Directory to save the model')
+parser.add_argument('--log_dir', default='./logs',
+                    help='Directory to save the log')
+parser.add_argument('--lr', type=float, default=5e-4)
+parser.add_argument('--lr_decay', type=float, default=1e-5)
+parser.add_argument('--max_iter', type=int, default=160000)
+parser.add_argument('--batch_size', type=int, default=8)
+parser.add_argument('--style_weight', type=float, default=10.0)
+parser.add_argument('--content_weight', type=float, default=7.0)
+parser.add_argument('--n_threads', type=int, default=1)
+parser.add_argument('--save_model_interval', type=int, default=10000)
+parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
+                    help="Type of positional embedding to use on top of the image features")
+parser.add_argument('--hidden_dim', default=512, type=int,
+                    help="Size of the embeddings (dimension of the transformer)")
+args = parser.parse_args()
+
+USE_CUDA = torch.cuda.is_available()
+device = torch.device("cuda" if USE_CUDA else "cpu")
+print(device)
+
+if not os.path.exists(args.save_dir):
+    os.makedirs(args.save_dir)
+
+if not os.path.exists(args.log_dir):
+    os.mkdir(args.log_dir)
+writer = SummaryWriter(log_dir=args.log_dir)
+
+vgg = StyTR.vgg
+vgg.load_state_dict(torch.load(args.vgg))
+vgg = nn.Sequential(*list(vgg.children())[:44])
+
+decoder = StyTR.decoder
+embedding = StyTR.PatchEmbed()
+
+Trans = transformer.Transformer()
+with torch.no_grad():
+    network = StyTR.StyTrans(vgg, decoder, embedding, Trans, args)
+network.train()
+
+network.to(device)
+network = nn.DataParallel(network, device_ids=[0, 1])
+content_tf = train_transform()
+style_tf = train_transform()
+
+content_dataset = FlatFolderDataset(args.content_dir, content_tf)
+style_dataset = FlatFolderDataset(args.style_dir, style_tf)
+
+content_iter = iter(data.DataLoader(
+    content_dataset, batch_size=args.batch_size,
+    sampler=InfiniteSamplerWrapper(content_dataset),
+    num_workers=args.n_threads))
+style_iter = iter(data.DataLoader(
+    style_dataset, batch_size=args.batch_size,
+    sampler=InfiniteSamplerWrapper(style_dataset),
+    num_workers=args.n_threads))
+
+optimizer = torch.optim.Adam([
+    {'params': network.module.transformer.parameters()},
+    {'params': network.module.decode.parameters()},
+    {'params': network.module.embedding.parameters()},
+], lr=args.lr)
+
+if not os.path.exists(args.save_dir + "/test"):
+    os.makedirs(args.save_dir + "/test")
+
+for i in tqdm(range(args.max_iter)):
+    if i < 1e4:
+        warmup_learning_rate(optimizer, iteration_count=i)
+    else:
+        adjust_learning_rate(optimizer, iteration_count=i)
+
+    # print('learning_rate: %s' % str(optimizer.param_groups[0]['lr']))
+    content_images = next(content_iter).to(device)
+    style_images = next(style_iter).to(device)
+    out, loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
+
+    if i % 100 == 0:
+        output_name = '{:s}/test/{:s}{:s}'.format(args.save_dir, str(i), ".jpg")
+        out = torch.cat((content_images, out), 0)
+        out = torch.cat((style_images, out), 0)
+        save_image(out, output_name)
+
+    loss_c = args.content_weight * loss_c
+    loss_s = args.style_weight * loss_s
+    loss = loss_c + loss_s + (l_identity1 * 70) + (l_identity2 * 1)
+
+    print(loss.sum().cpu().detach().numpy(),
+          "-content:", loss_c.sum().cpu().detach().numpy(),
+          "-style:", loss_s.sum().cpu().detach().numpy(),
+          "-l1:", l_identity1.sum().cpu().detach().numpy(),
+          "-l2:", l_identity2.sum().cpu().detach().numpy())
+
+    optimizer.zero_grad()
+    loss.sum().backward()
+    optimizer.step()
+
+    writer.add_scalar('loss_content', loss_c.sum().item(), i + 1)
+    writer.add_scalar('loss_style', loss_s.sum().item(), i + 1)
+    writer.add_scalar('loss_identity1', l_identity1.sum().item(), i + 1)
+    writer.add_scalar('loss_identity2', l_identity2.sum().item(), i + 1)
+    writer.add_scalar('total_loss', loss.sum().item(), i + 1)
+
+    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
+        state_dict = network.module.transformer.state_dict()
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].to(torch.device('cpu'))
+        torch.save(state_dict,
+                   '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir, i + 1))
+
+        state_dict = network.module.decode.state_dict()
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].to(torch.device('cpu'))
+        torch.save(state_dict,
+                   '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir, i + 1))
+
+        state_dict = network.module.embedding.state_dict()
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].to(torch.device('cpu'))
+        torch.save(state_dict,
+                   '{:s}/embedding_iter_{:d}.pth'.format(args.save_dir, i + 1))
+
+writer.close()
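
Note (illustrative, not part of the committed files): the two functions above combine into a linear warmup for the first 1e4 iterations followed by an inverse decay. A minimal sketch of the resulting learning rate, assuming the default --lr 5e-4 and --lr_decay 1e-5:

    # Sketch only: reproduces the schedule logic from train.py above,
    # assuming the default lr=5e-4 and lr_decay=1e-5.
    def scheduled_lr(iteration_count, lr=5e-4, lr_decay=1e-5):
        if iteration_count < 1e4:
            # warmup phase: grows linearly from 0.1 * lr towards ~2e-4
            return lr * 0.1 * (1.0 + 3e-4 * iteration_count)
        # decay phase: starts at 2e-4 and decays with iterations past warmup
        return 2e-4 / (1.0 + lr_decay * (iteration_count - 1e4))

    for it in (0, 5000, 10000, 80000, 160000):
        print(it, scheduled_lr(it))
    # 0 -> 5e-05, 5000 -> 1.25e-04, 10000 -> 2e-04, 80000 -> ~1.18e-04, 160000 -> 8e-05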
util/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
util/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (148 Bytes). View file

util/__pycache__/box_ops.cpython-310.pyc
ADDED
Binary file (2.72 kB). View file

util/__pycache__/misc.cpython-310.pyc
ADDED
Binary file (14.6 kB). View file
util/box_ops.py
ADDED
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes give inf / nan results,
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+
+def masks_to_boxes(masks):
+    """Compute the bounding boxes around the provided masks
+
+    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+    Returns a [N, 4] tensor, with the boxes in xyxy format
+    """
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device)
+
+    h, w = masks.shape[-2:]
+
+    y = torch.arange(0, h, dtype=torch.float)
+    x = torch.arange(0, w, dtype=torch.float)
+    y, x = torch.meshgrid(y, x)
+
+    x_mask = (masks * x.unsqueeze(0))
+    x_max = x_mask.flatten(1).max(-1)[0]
+    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+    y_mask = (masks * y.unsqueeze(0))
+    y_max = y_mask.flatten(1).max(-1)[0]
+    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+    return torch.stack([x_min, y_min, x_max, y_max], 1)
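
Note (illustrative, not part of the committed files): a quick sketch of how generalized_box_iou behaves on hand-made boxes in [x0, y0, x1, y1] format, assuming the module is importable as util.box_ops:

    import torch
    from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

    boxes1 = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 1.0, 1.0]]))  # -> [[0., 0., 1., 1.]]
    boxes2 = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                           [0.5, 0.5, 1.5, 1.5]])
    giou = generalized_box_iou(boxes1, boxes2)  # shape [1, 2]
    # identical boxes give GIoU = 1; the shifted box has IoU ~0.143 but
    # GIoU ~ -0.08 because of the enclosing-area penalty
    print(giou)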
util/misc.py
ADDED
@@ -0,0 +1,468 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from packaging import version
+from typing import Optional, List
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+if version.parse(torchvision.__version__) < version.parse('0.7'):
+    from torchvision.ops import _new_empty_tensor
+    from torchvision.ops.misc import _output_size
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device="cuda")
+    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}',
+                'max mem: {memory:.0f}'
+            ])
+        else:
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}'
+            ])
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+def get_sha():
+    cwd = os.path.dirname(os.path.abspath(__file__))
+
+    def _run(command):
+        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+    sha = 'N/A'
+    diff = "clean"
+    branch = 'N/A'
+    try:
+        sha = _run(['git', 'rev-parse', 'HEAD'])
+        subprocess.check_output(['git', 'diff'], cwd=cwd)
+        diff = _run(['git', 'diff-index', 'HEAD'])
+        diff = "has uncommitted changes" if diff else "clean"
+        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+    except Exception:
+        pass
+    message = f"sha: {sha}, status: {diff}, branch: {branch}"
+    return message
+
+
+def collate_fn(batch):
+    batch = list(zip(*batch))
+    batch[0] = nested_tensor_from_tensor_list(batch[0])
+    return tuple(batch)
+
+
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+class NestedTensor(object):
+    def __init__(self, tensors, mask: Optional[Tensor]):
+        self.tensors = tensors
+        self.mask = mask
+
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            assert mask is not None
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
+
+    def decompose(self):
+        return self.tensors, self.mask
+
+    def __repr__(self):
+        return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        if torchvision._is_tracing():
+            # nested_tensor_from_tensor_list() does not export well to ONNX
+            # call _onnx_nested_tensor_from_tensor_list() instead
+            return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], :img.shape[2]] = False
+    else:
+        raise ValueError('not supported')
+    return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+    max_size = []
+    for i in range(tensor_list[0].dim()):
+        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+        max_size.append(max_size_i)
+    max_size = tuple(max_size)
+
+    # work around for
+    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+    # m[: img.shape[1], :img.shape[2]] = False
+    # which is not yet supported in onnx
+    padded_imgs = []
+    padded_masks = []
+    for img in tensor_list:
+        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+        padded_imgs.append(padded_img)
+
+        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+        padded_masks.append(padded_mask.to(torch.bool))
+
+    tensor = torch.stack(padded_imgs)
+    mask = torch.stack(padded_masks)
+
+    return NestedTensor(tensor, mask=mask)
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    if target.numel() == 0:
+        return [torch.zeros([], device=output.device)]
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
+
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+    """
+    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+    This will eventually be supported natively by PyTorch, and this
+    class can go away.
+    """
+    if version.parse(torchvision.__version__) < version.parse('0.7'):
+        if input.numel() > 0:
+            return torch.nn.functional.interpolate(
+                input, size, scale_factor, mode, align_corners
+            )
+
+        output_shape = _output_size(2, input, size, scale_factor)
+        output_shape = list(input.shape[:-2]) + list(output_shape)
+        return _new_empty_tensor(input, output_shape)
+    else:
+        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
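
Note (illustrative, not part of the committed files): a minimal sketch of the NestedTensor padding helper above, assuming the module is importable as util.misc:

    import torch
    from util.misc import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 200, 300), torch.rand(3, 240, 260)]   # two images of different sizes
    nt = nested_tensor_from_tensor_list(imgs)                   # pads each image to the per-axis maximum
    tensors, mask = nt.decompose()
    print(tensors.shape)  # torch.Size([2, 3, 240, 300])
    print(mask.shape)     # torch.Size([2, 240, 300]); True marks padded pixels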
util/plot_utils.py
ADDED
@@ -0,0 +1,107 @@
+"""
+Plotting utilities to visualize training logs.
+"""
+import torch
+import pandas as pd
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path, PurePath
+
+
+def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
+    '''
+    Function to plot specific fields from training log(s). Plots both training and test results.
+
+    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
+              - fields = which results to plot from each log file - plots both training and test for each field.
+              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
+              - log_name = optional, name of log file if different than default 'log.txt'.
+
+    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
+               - solid lines are training results, dashed lines are test results.
+
+    '''
+    func_name = "plot_utils.py::plot_logs"
+
+    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
+    # convert single Path to list to avoid 'not iterable' error
+
+    if not isinstance(logs, list):
+        if isinstance(logs, PurePath):
+            logs = [logs]
+            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
+        else:
+            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
+            Expect list[Path] or single Path obj, received {type(logs)}")
+
+    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
+    for i, dir in enumerate(logs):
+        if not isinstance(dir, PurePath):
+            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
+        if not dir.exists():
+            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
+        # verify log_name exists
+        fn = Path(dir / log_name)
+        if not fn.exists():
+            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
+            print(f"--> full path of missing log file: {fn}")
+            return
+
+    # load log file(s) and plot
+    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
+
+    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
+
+    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
+        for j, field in enumerate(fields):
+            if field == 'mAP':
+                coco_eval = pd.DataFrame(
+                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
+                ).ewm(com=ewm_col).mean()
+                axs[j].plot(coco_eval, c=color)
+            else:
+                df.interpolate().ewm(com=ewm_col).mean().plot(
+                    y=[f'train_{field}', f'test_{field}'],
+                    ax=axs[j],
+                    color=[color] * 2,
+                    style=['-', '--']
+                )
+    for ax, field in zip(axs, fields):
+        ax.legend([Path(p).name for p in logs])
+        ax.set_title(field)
+
+
+def plot_precision_recall(files, naming_scheme='iter'):
+    if naming_scheme == 'exp_id':
+        # name becomes exp_id
+        names = [f.parts[-3] for f in files]
+    elif naming_scheme == 'iter':
+        names = [f.stem for f in files]
+    else:
+        raise ValueError(f'not supported {naming_scheme}')
+    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
+    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
+        data = torch.load(f)
+        # precision is n_iou, n_points, n_cat, n_area, max_det
+        precision = data['precision']
+        recall = data['params'].recThrs
+        scores = data['scores']
+        # take precision for all classes, all areas and 100 detections
+        precision = precision[0, :, :, 0, -1].mean(1)
+        scores = scores[0, :, :, 0, -1].mean(1)
+        prec = precision.mean()
+        rec = data['recall'][0, :, 0, -1].mean()
+        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
+              f'score={scores.mean():0.3f}, ' +
+              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
+              )
+        axs[0].plot(recall, precision, c=color)
+        axs[1].plot(recall, scores, c=color)
+
+    axs[0].set_title('Precision / Recall')
+    axs[0].legend(names)
+    axs[1].set_title('Scores / Recall')
+    axs[1].legend(names)
+    return fig, axs
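
Note (illustrative, not part of the committed files): these plotting utilities appear to be carried over from the DETR codebase and expect a JSON-lines 'log.txt' per run, not the TensorBoard event files that train.py writes. A hedged usage sketch, with hypothetical log directories:

    from pathlib import Path
    from util.plot_utils import plot_logs

    # each directory is assumed to contain a DETR-style 'log.txt' with one JSON record per line
    log_dirs = [Path('./logs/run1'), Path('./logs/run2')]
    plot_logs(log_dirs, fields=('loss', 'loss_bbox_unscaled'))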