culture commited on
Commit
818ce2d
·
1 Parent(s): 39f1cda

Upload gfpgan/archs/stylegan2_clean_arch.py

Browse files
Files changed (1) hide show
  1. gfpgan/archs/stylegan2_clean_arch.py +368 -0
gfpgan/archs/stylegan2_clean_arch.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.archs.arch_util import default_init_weights
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class NormStyleCode(nn.Module):
11
+
12
+ def forward(self, x):
13
+ """Normalize the style codes.
14
+
15
+ Args:
16
+ x (Tensor): Style codes with shape (b, c).
17
+
18
+ Returns:
19
+ Tensor: Normalized tensor.
20
+ """
21
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22
+
23
+
24
+ class ModulatedConv2d(nn.Module):
25
+ """Modulated Conv2d used in StyleGAN2.
26
+
27
+ There is no bias in ModulatedConv2d.
28
+
29
+ Args:
30
+ in_channels (int): Channel number of the input.
31
+ out_channels (int): Channel number of the output.
32
+ kernel_size (int): Size of the convolving kernel.
33
+ num_style_feat (int): Channel number of style features.
34
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
35
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
36
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
37
+ """
38
+
39
+ def __init__(self,
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size,
43
+ num_style_feat,
44
+ demodulate=True,
45
+ sample_mode=None,
46
+ eps=1e-8):
47
+ super(ModulatedConv2d, self).__init__()
48
+ self.in_channels = in_channels
49
+ self.out_channels = out_channels
50
+ self.kernel_size = kernel_size
51
+ self.demodulate = demodulate
52
+ self.sample_mode = sample_mode
53
+ self.eps = eps
54
+
55
+ # modulation inside each modulated conv
56
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
57
+ # initialization
58
+ default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
59
+
60
+ self.weight = nn.Parameter(
61
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
62
+ math.sqrt(in_channels * kernel_size**2))
63
+ self.padding = kernel_size // 2
64
+
65
+ def forward(self, x, style):
66
+ """Forward function.
67
+
68
+ Args:
69
+ x (Tensor): Tensor with shape (b, c, h, w).
70
+ style (Tensor): Tensor with shape (b, num_style_feat).
71
+
72
+ Returns:
73
+ Tensor: Modulated tensor after convolution.
74
+ """
75
+ b, c, h, w = x.shape # c = c_in
76
+ # weight modulation
77
+ style = self.modulation(style).view(b, 1, c, 1, 1)
78
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
79
+ weight = self.weight * style # (b, c_out, c_in, k, k)
80
+
81
+ if self.demodulate:
82
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
83
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
84
+
85
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
86
+
87
+ # upsample or downsample if necessary
88
+ if self.sample_mode == 'upsample':
89
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
90
+ elif self.sample_mode == 'downsample':
91
+ x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
92
+
93
+ b, c, h, w = x.shape
94
+ x = x.view(1, b * c, h, w)
95
+ # weight: (b*c_out, c_in, k, k), groups=b
96
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
97
+ out = out.view(b, self.out_channels, *out.shape[2:4])
98
+
99
+ return out
100
+
101
+ def __repr__(self):
102
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
103
+ f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
104
+
105
+
106
+ class StyleConv(nn.Module):
107
+ """Style conv used in StyleGAN2.
108
+
109
+ Args:
110
+ in_channels (int): Channel number of the input.
111
+ out_channels (int): Channel number of the output.
112
+ kernel_size (int): Size of the convolving kernel.
113
+ num_style_feat (int): Channel number of style features.
114
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
115
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
116
+ """
117
+
118
+ def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
119
+ super(StyleConv, self).__init__()
120
+ self.modulated_conv = ModulatedConv2d(
121
+ in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
122
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
123
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
124
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
125
+
126
+ def forward(self, x, style, noise=None):
127
+ # modulate
128
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
129
+ # noise injection
130
+ if noise is None:
131
+ b, _, h, w = out.shape
132
+ noise = out.new_empty(b, 1, h, w).normal_()
133
+ out = out + self.weight * noise
134
+ # add bias
135
+ out = out + self.bias
136
+ # activation
137
+ out = self.activate(out)
138
+ return out
139
+
140
+
141
+ class ToRGB(nn.Module):
142
+ """To RGB (image space) from features.
143
+
144
+ Args:
145
+ in_channels (int): Channel number of input.
146
+ num_style_feat (int): Channel number of style features.
147
+ upsample (bool): Whether to upsample. Default: True.
148
+ """
149
+
150
+ def __init__(self, in_channels, num_style_feat, upsample=True):
151
+ super(ToRGB, self).__init__()
152
+ self.upsample = upsample
153
+ self.modulated_conv = ModulatedConv2d(
154
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
155
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
156
+
157
+ def forward(self, x, style, skip=None):
158
+ """Forward function.
159
+
160
+ Args:
161
+ x (Tensor): Feature tensor with shape (b, c, h, w).
162
+ style (Tensor): Tensor with shape (b, num_style_feat).
163
+ skip (Tensor): Base/skip tensor. Default: None.
164
+
165
+ Returns:
166
+ Tensor: RGB images.
167
+ """
168
+ out = self.modulated_conv(x, style)
169
+ out = out + self.bias
170
+ if skip is not None:
171
+ if self.upsample:
172
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
173
+ out = out + skip
174
+ return out
175
+
176
+
177
+ class ConstantInput(nn.Module):
178
+ """Constant input.
179
+
180
+ Args:
181
+ num_channel (int): Channel number of constant input.
182
+ size (int): Spatial size of constant input.
183
+ """
184
+
185
+ def __init__(self, num_channel, size):
186
+ super(ConstantInput, self).__init__()
187
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
188
+
189
+ def forward(self, batch):
190
+ out = self.weight.repeat(batch, 1, 1, 1)
191
+ return out
192
+
193
+
194
+ @ARCH_REGISTRY.register()
195
+ class StyleGAN2GeneratorClean(nn.Module):
196
+ """Clean version of StyleGAN2 Generator.
197
+
198
+ Args:
199
+ out_size (int): The spatial size of outputs.
200
+ num_style_feat (int): Channel number of style features. Default: 512.
201
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
202
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
203
+ narrow (float): Narrow ratio for channels. Default: 1.0.
204
+ """
205
+
206
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
207
+ super(StyleGAN2GeneratorClean, self).__init__()
208
+ # Style MLP layers
209
+ self.num_style_feat = num_style_feat
210
+ style_mlp_layers = [NormStyleCode()]
211
+ for i in range(num_mlp):
212
+ style_mlp_layers.extend(
213
+ [nn.Linear(num_style_feat, num_style_feat, bias=True),
214
+ nn.LeakyReLU(negative_slope=0.2, inplace=True)])
215
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
216
+ # initialization
217
+ default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
218
+
219
+ # channel list
220
+ channels = {
221
+ '4': int(512 * narrow),
222
+ '8': int(512 * narrow),
223
+ '16': int(512 * narrow),
224
+ '32': int(512 * narrow),
225
+ '64': int(256 * channel_multiplier * narrow),
226
+ '128': int(128 * channel_multiplier * narrow),
227
+ '256': int(64 * channel_multiplier * narrow),
228
+ '512': int(32 * channel_multiplier * narrow),
229
+ '1024': int(16 * channel_multiplier * narrow)
230
+ }
231
+ self.channels = channels
232
+
233
+ self.constant_input = ConstantInput(channels['4'], size=4)
234
+ self.style_conv1 = StyleConv(
235
+ channels['4'],
236
+ channels['4'],
237
+ kernel_size=3,
238
+ num_style_feat=num_style_feat,
239
+ demodulate=True,
240
+ sample_mode=None)
241
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
242
+
243
+ self.log_size = int(math.log(out_size, 2))
244
+ self.num_layers = (self.log_size - 2) * 2 + 1
245
+ self.num_latent = self.log_size * 2 - 2
246
+
247
+ self.style_convs = nn.ModuleList()
248
+ self.to_rgbs = nn.ModuleList()
249
+ self.noises = nn.Module()
250
+
251
+ in_channels = channels['4']
252
+ # noise
253
+ for layer_idx in range(self.num_layers):
254
+ resolution = 2**((layer_idx + 5) // 2)
255
+ shape = [1, 1, resolution, resolution]
256
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
257
+ # style convs and to_rgbs
258
+ for i in range(3, self.log_size + 1):
259
+ out_channels = channels[f'{2**i}']
260
+ self.style_convs.append(
261
+ StyleConv(
262
+ in_channels,
263
+ out_channels,
264
+ kernel_size=3,
265
+ num_style_feat=num_style_feat,
266
+ demodulate=True,
267
+ sample_mode='upsample'))
268
+ self.style_convs.append(
269
+ StyleConv(
270
+ out_channels,
271
+ out_channels,
272
+ kernel_size=3,
273
+ num_style_feat=num_style_feat,
274
+ demodulate=True,
275
+ sample_mode=None))
276
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
277
+ in_channels = out_channels
278
+
279
+ def make_noise(self):
280
+ """Make noise for noise injection."""
281
+ device = self.constant_input.weight.device
282
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
283
+
284
+ for i in range(3, self.log_size + 1):
285
+ for _ in range(2):
286
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
287
+
288
+ return noises
289
+
290
+ def get_latent(self, x):
291
+ return self.style_mlp(x)
292
+
293
+ def mean_latent(self, num_latent):
294
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
295
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
296
+ return latent
297
+
298
+ def forward(self,
299
+ styles,
300
+ input_is_latent=False,
301
+ noise=None,
302
+ randomize_noise=True,
303
+ truncation=1,
304
+ truncation_latent=None,
305
+ inject_index=None,
306
+ return_latents=False):
307
+ """Forward function for StyleGAN2GeneratorClean.
308
+
309
+ Args:
310
+ styles (list[Tensor]): Sample codes of styles.
311
+ input_is_latent (bool): Whether input is latent style. Default: False.
312
+ noise (Tensor | None): Input noise or None. Default: None.
313
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
314
+ truncation (float): The truncation ratio. Default: 1.
315
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
316
+ inject_index (int | None): The injection index for mixing noise. Default: None.
317
+ return_latents (bool): Whether to return style latents. Default: False.
318
+ """
319
+ # style codes -> latents with Style MLP layer
320
+ if not input_is_latent:
321
+ styles = [self.style_mlp(s) for s in styles]
322
+ # noises
323
+ if noise is None:
324
+ if randomize_noise:
325
+ noise = [None] * self.num_layers # for each style conv layer
326
+ else: # use the stored noise
327
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
328
+ # style truncation
329
+ if truncation < 1:
330
+ style_truncation = []
331
+ for style in styles:
332
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
333
+ styles = style_truncation
334
+ # get style latents with injection
335
+ if len(styles) == 1:
336
+ inject_index = self.num_latent
337
+
338
+ if styles[0].ndim < 3:
339
+ # repeat latent code for all the layers
340
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
341
+ else: # used for encoder with different latent code for each layer
342
+ latent = styles[0]
343
+ elif len(styles) == 2: # mixing noises
344
+ if inject_index is None:
345
+ inject_index = random.randint(1, self.num_latent - 1)
346
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
347
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
348
+ latent = torch.cat([latent1, latent2], 1)
349
+
350
+ # main generation
351
+ out = self.constant_input(latent.shape[0])
352
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
353
+ skip = self.to_rgb1(out, latent[:, 1])
354
+
355
+ i = 1
356
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
357
+ noise[2::2], self.to_rgbs):
358
+ out = conv1(out, latent[:, i], noise=noise1)
359
+ out = conv2(out, latent[:, i + 1], noise=noise2)
360
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
361
+ i += 2
362
+
363
+ image = skip
364
+
365
+ if return_latents:
366
+ return image, latent
367
+ else:
368
+ return image, None