import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from .conv_blocks import (
    ConvBlock,
    DownConv,
    InvertedResBlock,
    SeparableConv2D,
    UpConv,
)
from .layers import get_norm
from utils.common import initialize_weights

class GeneratorV1(nn.Module):
    """Encoder-residual-decoder generator mapping a 3-channel image to a 3-channel image."""

    def __init__(self, dataset=''):
        super(GeneratorV1, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        bias = False

        # Encoder: two DownConv stages downsample the input while growing to 256 channels.
        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=bias),
            ConvBlock(64, 128, bias=bias),
            DownConv(128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            SeparableConv2D(128, 256, bias=bias),
            DownConv(256, bias=bias),
            ConvBlock(256, 256, bias=bias),
        )

        # Bottleneck: eight inverted residual blocks at 256 channels.
        self.res_blocks = nn.Sequential(
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
            InvertedResBlock(256, 256, bias=bias),
        )

        # Decoder: two UpConv stages restore resolution; Tanh bounds the output to [-1, 1].
        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=bias),
            UpConv(128, bias=bias),
            SeparableConv2D(128, 128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            UpConv(128, bias=bias),
            ConvBlock(128, 64, bias=bias),
            ConvBlock(64, 64, bias=bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        out = self.encode_blocks(x)
        out = self.res_blocks(out)
        img = self.decode_blocks(out)
        return img

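# Example usage of GeneratorV1 (a minimal sketch, assuming the relative imports above
# resolve, i.e. this module lives in a package next to conv_blocks.py and layers.py):
#
#   generator = GeneratorV1(dataset="demo")
#   image = torch.rand(1, 3, 256, 256) * 2 - 1   # NCHW batch, roughly in [-1, 1]
#   styled = generator(image)                    # Tanh output in [-1, 1]; shape should match
#                                                # the input for sizes divisible by 4
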
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a per-patch real/fake logit map."""

    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super(Discriminator, self).__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        channels = 32

        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]

        # Each stage: a stride-2 conv that halves resolution, then a stride-1 conv with normalization.
        in_channels = channels
        for i in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                get_norm(norm_type)(channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2
        channels *= 2  # now equal to the channel count produced by the last stage

        # Head: one more normalized conv, then a single-channel logit map.
        layers += [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            get_norm(norm_type)(channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        # Optionally wrap every convolution with spectral normalization.
        if use_sn:
            for i in range(len(layers)):
                if isinstance(layers[i], nn.Conv2d):
                    layers[i] = spectral_norm(layers[i])

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        logits = self.discriminate(img)
        return logits
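

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original training code): push a random
    # batch through both networks and print the output shapes. Assumes the package
    # imports resolve and get_norm("instance") returns an InstanceNorm2d-like constructor.
    x = torch.randn(1, 3, 256, 256)

    generator = GeneratorV1(dataset="demo")
    discriminator = Discriminator(dataset="demo", num_layers=1, use_sn=True, norm_type="instance")

    with torch.no_grad():
        fake = generator(x)            # expected (1, 3, 256, 256), bounded by Tanh
        logits = discriminator(fake)   # per-patch logits, spatially reduced by the stride-2 convs

    print(generator.name, tuple(fake.shape))
    print(discriminator.name, tuple(logits.shape))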