# pages/23_gan.py: train a simple fully connected GAN on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
# Load and Preprocess the MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale pixel values to [-1, 1] to match the generator's Tanh output
])
mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=128, shuffle=True)
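
# Optional sanity check (not part of the original script): each batch should have shape
# [128, 1, 28, 28] with values in [-1, 1] after normalization.
# sample_batch, _ = next(iter(dataloader))
# print(sample_batch.shape, sample_batch.min().item(), sample_batch.max().item())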
# Define the Generator and Discriminator Networks
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28*28),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized real images
        )

    def forward(self, input):
        # Map a 100-dimensional latent vector to a 1x28x28 image
        return self.main(input).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, input):
        # Flatten each 1x28x28 image before the fully connected layers
        return self.main(input.view(-1, 28*28))
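
# Illustrative shape check (not in the original script; the underscore-prefixed names are
# hypothetical): latent vectors should map to [N, 1, 28, 28] images, and the discriminator
# should return one score per image.
# _g, _d = Generator(), Discriminator()
# _z = torch.randn(4, 100)
# assert _g(_z).shape == (4, 1, 28, 28)
# assert _d(_g(_z)).shape == (4, 1)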
# Initialize Models, Optimizers, and Loss Function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
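
# Note: BCELoss on the discriminator's Sigmoid output gives the standard GAN objective,
# and Adam with lr=0.0002, betas=(0.5, 0.999) follows the settings popularized by the
# DCGAN paper.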
# Train the GAN
num_epochs = 50
fixed_noise = torch.randn(64, 100, device=device)  # fixed latent vectors to track generator progress across epochs

for epoch in range(num_epochs):
    for i, (data, _) in enumerate(dataloader):
        # Train Discriminator on real images (target label 1)
        netD.zero_grad()
        real_data = data.to(device)
        b_size = real_data.size(0)
        label = torch.full((b_size,), 1., dtype=torch.float, device=device)
        output = netD(real_data).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        # Train Discriminator on fake images (target label 0)
        noise = torch.randn(b_size, 100, device=device)
        fake_data = netG(noise)
        label.fill_(0.)
        output = netD(fake_data.detach()).view(-1)  # detach so this pass does not update the generator
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()

        # Train Generator: push the discriminator to classify fakes as real (target label 1)
        netG.zero_grad()
        label.fill_(1.)
        output = netD(fake_data).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

    print(f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD_real.item()+errD_fake.item()} Loss_G: {errG.item()}')

    # Visualize samples from the fixed noise every 10 epochs
    if epoch % 10 == 0:
        with torch.no_grad():
            fake_images = netG(fixed_noise).detach().cpu()
        plt.figure(figsize=(10, 10))
        plt.axis("off")
        plt.title(f"Generated Images at Epoch {epoch}")
        plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
        plt.show()
# Generate and Visualize Synthetic Images
with torch.no_grad():
    noise = torch.randn(64, 100, device=device)
    fake_images = netG(noise).detach().cpu()

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
plt.show()
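
# Optional follow-up (not in the original script): persist the trained generator so it can
# be reloaded later without retraining; the filename here is an arbitrary choice.
# torch.save(netG.state_dict(), "generator_mnist.pth")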