import torch.nn as nn


class Block(nn.Module):
    """Conv -> BatchNorm -> activation unit, optionally followed by dropout.

    With ``down=True`` the block uses a 4x4, stride-2, padding-1 ``Conv2d``
    with reflect padding. For an even input size H this gives an output of
    size H/2. With ``down=False`` it uses a 4x4, stride-2, padding-1
    ``ConvTranspose2d``, which gives an output of size 2H.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down: ``True`` selects the strided convolution (downsampling);
            ``False`` selects the transposed convolution (upsampling).
        act: ``"relu"`` selects ``ReLU``. Any other value selects
            ``LeakyReLU(0.2)``.
        use_dropout: When ``True``, apply ``Dropout(0.5)`` after the
            activation. When ``False``, the activation output is returned as-is.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        # bias=False: the following BatchNorm subtracts the per-channel mean,
        # which would cancel a convolution bias anyway.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        # Always constructed (it has no parameters); only used when use_dropout is True.
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the conv/norm/activation stack, then dropout if enabled."""
        x = self.conv(x)
        # Honour the constructor flag: dropout must not run when use_dropout=False.
        return self.dropout(x) if self.use_dropout else x


class BlockCNN(nn.Module):
    """4x4 Conv2d -> BatchNorm -> LeakyReLU(0.2) unit.

    The convolution uses zero padding, so for input size H the output size is
    ``floor((H - 4) / stride) + 1``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 2).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        # NOTE(review): padding is left at Conv2d's default of 0, so
        # padding_mode="reflect" has no effect here. It is kept unchanged
        # because adding padding would alter the output spatial size;
        # confirm whether padding=1 was intended.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack."""
        return self.conv(x)