Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
import torchvision.utils as vutils | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Load and Preprocess the MNIST Dataset
# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5
# then maps them to [-1, 1], matching the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# MNIST training split, downloaded into ./data on first use.
mnist_dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 128 images for the training loop.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=128,
    shuffle=True,
)
# Define the Generator and Discriminator Networks
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to 1x28x28 images.

    The final Tanh bounds every pixel to [-1, 1], the same range the dataset
    transform ``Normalize((0.5,), (0.5,))`` produces for real images.

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100, the
            value the training script samples with ``torch.randn(..., 100)``.
    """

    def __init__(self, latent_dim: int = 100):
        super().__init__()
        # Exposed so callers can size their noise tensors from the model.
        self.latent_dim = latent_dim
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map noise whose last dimension is ``latent_dim`` to images.

        Returns:
            Tensor of shape (N, 1, 28, 28) with values in [-1, 1].
        """
        # The flat 784-vector from the last Linear layer is reshaped to NCHW.
        return self.main(input).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 single-channel images.

    Outputs one Sigmoid probability per image, shape (N, 1). The training
    loop trains it toward 1 for real images and toward 0 for generated ones.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Flatten each image to 784 features and return its probability."""
        # reshape (not view): view() raises on non-contiguous inputs such as
        # a transposed batch, while reshape copies only when it has to.
        return self.main(input.reshape(-1, 28 * 28))
# Initialize Models, Optimizers, and Loss Function
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Both networks live on the same device so their tensors can be mixed freely.
netG, netD = Generator().to(device), Discriminator().to(device)
# Binary cross-entropy on the discriminator's Sigmoid probabilities.
criterion = nn.BCELoss()
# Identical Adam settings for both players: lr 2e-4, beta1 0.5, beta2 0.999.
optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Train the GAN
num_epochs = 50
# One fixed latent batch, so every snapshot shows the *same* inputs evolving.
fixed_noise = torch.randn(64, 100, device=device)
for epoch in range(num_epochs):
    # Running sums of per-batch losses. They are kept as device tensors so that
    # accumulating them does not force a device->host sync on every step.
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for data, _ in dataloader:
        # ---- Train Discriminator: push D(real) -> 1 and D(fake) -> 0 ----
        netD.zero_grad()
        real_data = data.to(device)
        b_size = real_data.size(0)
        real_labels = torch.ones(b_size, device=device)
        fake_labels = torch.zeros(b_size, device=device)

        output = netD(real_data).view(-1)
        errD_real = criterion(output, real_labels)
        errD_real.backward()

        noise = torch.randn(b_size, 100, device=device)
        fake_data = netG(noise)
        # detach() keeps the discriminator update from backpropagating into netG.
        output = netD(fake_data.detach()).view(-1)
        errD_fake = criterion(output, fake_labels)
        errD_fake.backward()
        optimizerD.step()

        # ---- Train Generator: push D(G(z)) -> 1 ----
        netG.zero_grad()
        # This backward also writes grads into netD's parameters; netD.zero_grad()
        # at the start of the next iteration discards them before they are used.
        output = netD(fake_data).view(-1)
        errG = criterion(output, real_labels)
        errG.backward()
        optimizerG.step()

        d_loss_sum += (errD_real + errD_fake).detach()
        g_loss_sum += errG.detach()
        num_batches += 1

    # Report the mean over the epoch's batches, not just the final, noisy batch.
    # max(..., 1) guards against division by zero if the loader yields nothing.
    num_batches = max(num_batches, 1)
    avg_loss_d = d_loss_sum.item() / num_batches
    avg_loss_g = g_loss_sum.item() / num_batches
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}')

    # Snapshot every 10 epochs and after the last one. Titles use the same
    # 1-based epoch number as the log line above.
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        with torch.no_grad():
            fake_images = netG(fixed_noise).cpu()
        plt.figure(figsize=(10, 10))
        plt.axis("off")
        plt.title(f"Generated Images at Epoch {epoch + 1}")
        plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
        plt.show()
# Generate and Visualize Synthetic Images
# Draw a fresh batch of 64 latent vectors and render the generator's output
# as a single image grid; no autograd graph is needed for inference.
with torch.no_grad():
    noise = torch.randn(64, 100, device=device)
    fake_images = netG(noise).detach().cpu()
# make_grid tiles the batch into one CHW image; imshow expects HWC.
grid = vutils.make_grid(fake_images, padding=2, normalize=True)
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()