Spaces:
Running
Running
Upload 13 files
Browse files
- models/__init__.py +3 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/anime_gan.cpython-39.pyc +0 -0
- models/__pycache__/anime_gan_v2.cpython-39.pyc +0 -0
- models/__pycache__/anime_gan_v3.cpython-39.pyc +0 -0
- models/__pycache__/conv_blocks.cpython-39.pyc +0 -0
- models/__pycache__/layers.cpython-39.pyc +0 -0
- models/anime_gan.py +112 -0
- models/anime_gan_v2.py +61 -0
- models/anime_gan_v3.py +14 -0
- models/conv_blocks.py +185 -0
- models/layers.py +24 -0
- models/vgg.py +80 -0
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .anime_gan import GeneratorV1
|
2 |
+
from .anime_gan_v2 import GeneratorV2
|
3 |
+
from .anime_gan_v3 import GeneratorV3
|
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (285 Bytes). View file
|
|
models/__pycache__/anime_gan.cpython-39.pyc
ADDED
Binary file (2.82 kB). View file
|
|
models/__pycache__/anime_gan_v2.cpython-39.pyc
ADDED
Binary file (1.7 kB). View file
|
|
models/__pycache__/anime_gan_v3.cpython-39.pyc
ADDED
Binary file (698 Bytes). View file
|
|
models/__pycache__/conv_blocks.cpython-39.pyc
ADDED
Binary file (5.04 kB). View file
|
|
models/__pycache__/layers.cpython-39.pyc
ADDED
Binary file (1.25 kB). View file
|
|
models/anime_gan.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn.utils import spectral_norm
|
5 |
+
from .conv_blocks import DownConv
|
6 |
+
from .conv_blocks import UpConv
|
7 |
+
from .conv_blocks import SeparableConv2D
|
8 |
+
from .conv_blocks import InvertedResBlock
|
9 |
+
from .conv_blocks import ConvBlock
|
10 |
+
from .layers import get_norm
|
11 |
+
from utils.common import initialize_weights
|
12 |
+
|
13 |
+
|
14 |
+
class GeneratorV1(nn.Module):
    """First-generation generator: encoder -> 8 inverted residual blocks -> decoder.

    The encoder takes a 3-channel image and contains two ``DownConv`` stages;
    the decoder contains two ``UpConv`` stages and ends with a 1x1 convolution
    to 3 channels followed by ``Tanh``, so outputs lie in (-1, 1).

    Args:
        dataset: Label appended to the class name to form ``self.name``.
    """

    def __init__(self, dataset=''):
        super().__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        bias = False

        encoder_layers = [
            ConvBlock(3, 64, bias=bias),
            ConvBlock(64, 128, bias=bias),
            DownConv(128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            SeparableConv2D(128, 256, bias=bias),
            DownConv(256, bias=bias),
            ConvBlock(256, 256, bias=bias),
        ]
        self.encode_blocks = nn.Sequential(*encoder_layers)

        # Eight identical 256 -> 256 inverted residual blocks.
        self.res_blocks = nn.Sequential(
            *(InvertedResBlock(256, 256, bias=bias) for _ in range(8))
        )

        decoder_layers = [
            ConvBlock(256, 128, bias=bias),
            UpConv(128, bias=bias),
            SeparableConv2D(128, 128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            UpConv(128, bias=bias),
            ConvBlock(128, 64, bias=bias),
            ConvBlock(64, 64, bias=bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
            nn.Tanh(),
        ]
        self.decode_blocks = nn.Sequential(*decoder_layers)

        initialize_weights(self)

    def forward(self, x):
        """Run ``x`` through the encoder, the residual trunk and the decoder."""
        features = self.res_blocks(self.encode_blocks(x))
        return self.decode_blocks(features)
|
61 |
+
|
62 |
+
|
63 |
+
class Discriminator(nn.Module):
    """Convolutional discriminator producing a 1-channel map of logits.

    Layout: a 3 -> 32 stem conv, then ``num_layers`` stages (each a stride-2
    conv followed by a stride-1 conv with normalization), then a two-conv head
    ending in a single output channel. No activation is applied to the output.

    Args:
        dataset: Label used to build ``self.name``.
        num_layers: Number of stride-2 stages; ``0`` is allowed and yields
            stem + head only.
        use_sn: Wrap every ``nn.Conv2d`` in spectral normalization.
        norm_type: Normalization name understood by ``get_norm``.
    """

    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super().__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        channels = 32

        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]

        in_channels = channels
        for _ in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                get_norm(norm_type)(channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2

        # The head must consume exactly what the previous layer produced.
        # For num_layers >= 1 this equals the former ``channels *= 2``; for
        # num_layers == 0 the former expression gave 64 while the stem only
        # outputs 32 channels, which failed at the first forward pass.
        channels = in_channels
        layers += [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            get_norm(norm_type)(channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        if use_sn:
            layers = [
                spectral_norm(layer) if isinstance(layer, nn.Conv2d) else layer
                for layer in layers
            ]

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        """Return the raw logits map for ``img``."""
        return self.discriminate(img)
|
models/anime_gan_v2.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.nn as nn
|
3 |
+
from models.conv_blocks import InvertedResBlock
|
4 |
+
from models.conv_blocks import ConvBlock
|
5 |
+
from models.conv_blocks import UpConvLNormLReLU
|
6 |
+
from utils.common import initialize_weights
|
7 |
+
|
8 |
+
|
9 |
+
class GeneratorV2(nn.Module):
    """Second-generation generator built from layer-norm ``ConvBlock`` stages.

    Two stride-2 ``ConvBlock`` layers downsample, a residual trunk of four
    ``InvertedResBlock`` layers follows, and two ``UpConvLNormLReLU`` layers
    upsample again. The output is a 1x1 convolution to 3 channels followed by
    ``Tanh``.

    Args:
        dataset: Label appended to the class name to form ``self.name``.
    """

    def __init__(self, dataset=''):
        super().__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        bias = False

        def conv(c_in, c_out, k=3, s=1):
            # Every ConvBlock in this generator shares norm_type and bias.
            return ConvBlock(c_in, c_out, kernel_size=k, stride=s, norm_type="layer", bias=bias)

        def res(c_in, c_out):
            return InvertedResBlock(c_in, c_out, expand_ratio=2, norm_type="layer", bias=bias)

        self.conv_block1 = nn.Sequential(
            conv(3, 32, k=7),
            conv(32, 64, s=2),
            conv(64, 64),
        )

        self.conv_block2 = nn.Sequential(
            conv(64, 128, s=2),
            conv(128, 128),
        )

        self.res_blocks = nn.Sequential(
            conv(128, 128),
            res(128, 256),
            res(256, 256),
            res(256, 256),
            res(256, 256),
            conv(256, 128),
        )

        self.upsample1 = nn.Sequential(
            UpConvLNormLReLU(128, 128),
            conv(128, 128),
        )

        self.upsample2 = nn.Sequential(
            UpConvLNormLReLU(128, 64),
            conv(64, 64),
            conv(64, 32, k=7),
        )

        self.decode_blocks = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        """Apply every stage in order and return the generated image."""
        out = x
        for stage in (
            self.conv_block1,
            self.conv_block2,
            self.res_blocks,
            self.upsample1,
            self.upsample2,
        ):
            out = stage(out)
        return self.decode_blocks(out)
|
models/anime_gan_v3.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn.utils import spectral_norm
|
5 |
+
from models.conv_blocks import DownConv
|
6 |
+
from models.conv_blocks import UpConv
|
7 |
+
from models.conv_blocks import SeparableConv2D
|
8 |
+
from models.conv_blocks import InvertedResBlock
|
9 |
+
from models.conv_blocks import ConvBlock
|
10 |
+
from utils.common import initialize_weights
|
11 |
+
|
12 |
+
|
13 |
+
class GeneratorV3(nn.Module):
    """Placeholder for a third-generation generator; defines no layers yet."""
|
models/conv_blocks.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from utils.common import initialize_weights
|
4 |
+
from .layers import LayerNorm2d
|
5 |
+
|
6 |
+
|
7 |
+
class DownConv(nn.Module):
    """Halve spatial resolution by summing two downsampling branches.

    Branch 1 is a stride-2 ``SeparableConv2D``. Branch 2 bilinearly resizes the
    input and then applies a stride-1 ``SeparableConv2D``. The channel count is
    unchanged.

    Args:
        channels: Number of input (and output) channels.
        bias: Forwarded to both separable convolutions.
    """

    def __init__(self, channels, bias=False):
        super().__init__()

        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        """Return the sum of the strided-conv branch and the resize+conv branch."""
        out1 = self.conv1(x)
        # Resize to the conv branch's actual output size rather than using
        # scale_factor=0.5: for odd H/W the stride-2 conv rounds up while a 0.5
        # scale factor rounds down, so the two branches could not be added.
        # For even sizes this is the same resize as before.
        out2 = F.interpolate(x, size=tuple(out1.shape[-2:]), mode='bilinear')
        out2 = self.conv2(out2)

        return out1 + out2
|
21 |
+
|
22 |
+
|
23 |
+
class UpConv(nn.Module):
    """Bilinear 2x upsampling followed by a stride-1 ``SeparableConv2D``.

    Args:
        channels: Number of input (and output) channels.
        bias: Forwarded to the separable convolution.
    """

    def __init__(self, channels, bias=False):
        super().__init__()

        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        """Upsample ``x`` by a factor of two, then convolve."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        return self.conv(enlarged)
|
33 |
+
|
34 |
+
|
35 |
+
class UpConvLNormLReLU(nn.Module):
    """Bilinear 2x upsampling followed by a 3x3 ``ConvBlock``.

    NOTE(review): the class name suggests layer normalization, but no
    ``norm_type`` is passed to ``ConvBlock``, so its default applies. Confirm
    which is intended before changing it, since that would alter the module's
    parameters.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        bias: Forwarded to the ``ConvBlock``.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()

        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size=3, bias=bias)

    def forward(self, x):
        """Upsample ``x`` by a factor of two, then apply the conv block."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        return self.conv_block(enlarged)
|
51 |
+
|
52 |
+
class SeparableConv2D(nn.Module):
    """Depthwise 3x3 conv followed by a pointwise 1x1 conv.

    Each convolution is followed by instance normalization and
    ``LeakyReLU(0.2)``.

    Args:
        in_channels: Channels of the input (also the depthwise group count).
        out_channels: Channels produced by the pointwise convolution.
        stride: Stride of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, stride=1, bias=False):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=bias,
        )
        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
        self.activation1 = nn.LeakyReLU(0.2, True)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation2 = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        """Apply depthwise -> norm -> act -> pointwise -> norm -> act."""
        spatial = self.activation1(self.ins_norm1(self.depthwise(x)))
        mixed = self.ins_norm2(self.pointwise(spatial))
        return self.activation2(mixed)
|
76 |
+
|
77 |
+
|
78 |
+
class ConvBlock(nn.Module):
    """Optional reflection pad -> Conv2d -> normalization -> LeakyReLU(0.2).

    Args:
        channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Forwarded to ``nn.Conv2d``. The default ``"valid"`` adds no
            implicit padding; the reflection pad selected below provides it.
        bias: Whether the convolution carries a bias term.
        norm_type: ``"instance"`` or ``"layer"``.

    Raises:
        ValueError: If ``norm_type`` is not a supported name.
    """

    def __init__(
        self,
        channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding="valid",
        bias=False,
        norm_type="instance"
    ):
        super().__init__()

        # ReflectionPad2d takes (left, right, top, bottom).
        if kernel_size == 3 and stride == 1:
            self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        elif kernel_size == 7 and stride == 1:
            self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
        elif stride == 2:
            # NOTE(review): this pads right and top only; confirm the
            # asymmetry is intended before changing it, since trained weights
            # may depend on it.
            self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
        else:
            self.pad = None

        self.conv = nn.Conv2d(
            channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        if norm_type == "instance":
            self.ins_norm = nn.InstanceNorm2d(out_channels)
        elif norm_type == "layer":
            self.ins_norm = LayerNorm2d(out_channels)
        else:
            # Fail at construction time; previously an unknown name left
            # ``ins_norm`` undefined and only surfaced as an AttributeError
            # inside forward().
            raise ValueError(norm_type)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        """Pad (if configured), convolve, normalize and activate ``x``."""
        if self.pad is not None:
            x = self.pad(x)
        out = self.conv(x)
        out = self.ins_norm(out)
        out = self.activation(out)
        return out
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
class InvertedResBlock(nn.Module):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    The skip connection is added only when the output has the same number of
    channels as the input.

    Args:
        channels: Channels of the input feature map.
        out_channels: Channels produced by the projection convolution.
        expand_ratio: Multiplier giving the bottleneck width
            (``round(expand_ratio * channels)``).
        bias: Whether the convolutions carry a bias term.
        norm_type: ``"instance"`` or ``"layer"``; selects ``ins_norm1`` and
            ``ins_norm2``.

    Raises:
        ValueError: If ``norm_type`` is not a supported name.
    """

    def __init__(
        self,
        channels=256,
        out_channels=256,
        expand_ratio=2,
        bias=False,
        norm_type="instance",
    ):
        super().__init__()
        bottleneck_dim = round(expand_ratio * channels)
        # NOTE(review): norm_type is not forwarded here, so the expansion
        # ConvBlock always uses its default norm. Left as is because changing
        # it would alter the module's parameters; confirm intent.
        self.conv_block = ConvBlock(
            channels,
            bottleneck_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias
        )
        self.depthwise_conv = nn.Conv2d(
            bottleneck_dim,
            bottleneck_dim,
            kernel_size=3,
            groups=bottleneck_dim,
            stride=1,
            padding=1,
            bias=bias
        )
        self.conv = nn.Conv2d(
            bottleneck_dim,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=bias
        )

        if norm_type == "instance":
            # ins_norm1 runs on the depthwise output, which has bottleneck_dim
            # channels. Sizing it with out_channels triggers a num_features
            # mismatch warning on each forward pass; the norm has no
            # parameters, so results are unchanged.
            self.ins_norm1 = nn.InstanceNorm2d(bottleneck_dim)
            self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        elif norm_type == "layer":
            # Keep var name as is for v1 compatibility.
            self.ins_norm1 = LayerNorm2d(bottleneck_dim)
            self.ins_norm2 = LayerNorm2d(out_channels)
        else:
            # Previously an unknown name left both norms undefined and only
            # failed inside forward().
            raise ValueError(norm_type)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        """Apply the block; add the residual when channel counts match."""
        out = self.conv_block(x)
        out = self.depthwise_conv(out)
        out = self.ins_norm1(out)
        out = self.activation(out)
        out = self.conv(out)
        out = self.ins_norm2(out)

        if out.shape[1] != x.shape[1]:
            # Channel counts differ, so no residual can be added.
            return out
        return out + x
|
models/layers.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class LayerNorm2d(nn.LayerNorm):
    """Layer normalization over the channel axis of NCHW tensors.

    Args:
        num_channels: Size of the channel axis being normalized.
        eps: Value added to the variance for numerical stability.
        affine: Whether to learn a per-channel scale and shift.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` (N, C, H, W) across C and return it in NCHW order."""
        # layer_norm normalizes trailing dims, so move channels last first.
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(
            channels_last,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normed.permute(0, 3, 1, 2)
|
16 |
+
|
17 |
+
|
18 |
+
def get_norm(norm_type):
    """Return the normalization layer class registered under ``norm_type``.

    Args:
        norm_type: ``"instance"`` or ``"layer"``.

    Returns:
        ``nn.InstanceNorm2d`` or ``LayerNorm2d`` (the class, not an instance).

    Raises:
        ValueError: If ``norm_type`` is not a supported name.
    """
    if norm_type == "instance":
        return nn.InstanceNorm2d
    if norm_type == "layer":
        return LayerNorm2d
    raise ValueError(norm_type)
|
models/vgg.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy.lib.arraysetops import isin
|
2 |
+
import torchvision.models as models
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class Vgg19(nn.Module):
    """Pretrained VGG19 feature extractor truncated at a named conv layer.

    Inputs are expected in the range [-1, 1]; they are rescaled to [0, 1] and
    normalized with the per-channel ``mean``/``std`` before entering the
    network.

    ``mean`` and ``std`` are registered as non-persistent buffers, so they
    follow the module through ``.to()``, ``.cuda()``, ``.double()`` and similar
    calls without adding entries to ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.vgg19 = self.get_vgg19().eval()
        vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float()
        vgg_std = torch.tensor([0.229, 0.224, 0.225]).float()
        # Shape (3, 1, 1) so the statistics broadcast over H and W.
        self.register_buffer("mean", vgg_mean.view(-1, 1, 1), persistent=False)
        self.register_buffer("std", vgg_std.view(-1, 1, 1), persistent=False)

    def forward(self, x):
        """Return the truncated network's features for ``x`` in [-1, 1]."""
        return self.vgg19(self.normalize_vgg(x))

    @staticmethod
    def get_vgg19(last_layer='conv4_4'):
        """Build VGG19 ``features`` cut right after the conv named ``last_layer``.

        Convs are named ``conv{block}_{index}``: ``block`` increases after each
        max-pool and ``index`` counts convs within the block. If no conv
        matches ``last_layer``, every feature layer is kept.
        """
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        model_list = []

        conv_index = 0
        block_index = 1
        for layer in vgg.children():
            if isinstance(layer, nn.MaxPool2d):
                conv_index = 0
                block_index += 1
            elif isinstance(layer, nn.Conv2d):
                conv_index += 1
                if f'conv{block_index}_{conv_index}' == last_layer:
                    # Keep the target conv itself, but nothing after it.
                    model_list.append(layer)
                    break

            model_list.append(layer)

        return nn.Sequential(*model_list)

    def normalize_vgg(self, image):
        """Map ``image`` from [-1, 1] to [0, 1], then apply mean/std normalization."""
        image = (image + 1.0) / 2.0
        return (image - self.mean) / self.std
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
    # Manual smoke test: push one example image through Vgg19 and print the
    # shape of the resulting feature map.
    from PIL import Image
    import numpy as np
    from utils.image_processing import normalize_input

    picture = Image.open("example/10.jpg").resize((224, 224))
    pixels = normalize_input(np.array(picture).astype('float32'))

    # HWC array -> CHW tensor with a leading batch axis.
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)

    extractor = Vgg19()
    print(extractor(batch).shape)
|