57894-Pix2Pix / models /discriminator.py
Muhammad Naufal Rizqullah
first commit
ae0af75
import torch
import torch.nn as nn
from models.base import BlockCNN
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512], kernel_size=4, activation_slope=0.2, ):
super().__init__()
self.initial = nn.Sequential(
nn.Conv2d(
in_channels * 2,
features[0],
kernel_size,
stride=2,
padding=1,
padding_mode="reflect",
),
nn.LeakyReLU(activation_slope),
)
layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(
BlockCNN(in_channels, feature, stride=1 if feature == features[-1] else 2)
)
in_channels = feature
layers.append(
nn.Conv2d(
in_channels, 1, kernel_size=kernel_size, stride=1, padding=1, padding_mode="reflect"
)
)
self.model = nn.Sequential(*layers)
def forward(self, x, y):
x = torch.cat([x, y], dim=1)
x = self.initial(x)
return self.model(x)
def test():
# Test Case for Discriminator Model
x = torch.randn((1, 3, 256, 256))
disc = Discriminator()
print(f"Discriminator Output Shape: {disc(x, x).shape}")