# Source: Hugging Face Space file view (Space status: Running)
# File size: 3,740 bytes, revision fb402a1
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
# Load and Preprocess the MNIST Dataset
# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 128 (image, label) pairs.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset, shuffle=True, batch_size=128
)
# Define the Generator and Discriminator Networks
class Generator(nn.Module):
    """Fully connected GAN generator that maps latent vectors to images.

    An MLP latent_dim -> 256 -> 512 -> 1024 -> channels*image_size**2 with
    ReLU activations and a final Tanh, so every output pixel lies in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector (default 100).
        image_size: Height and width of the square output image (default 28).
        channels: Number of output image channels (default 1).
    """

    def __init__(self, latent_dim=100, image_size=28, channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.channels = channels
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, channels * image_size * image_size),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate a batch of images.

        Args:
            input: Latent batch; expected shape (N, latent_dim).

        Returns:
            Tensor of shape (N, channels, image_size, image_size), values in [-1, 1].
        """
        return self.main(input).view(-1, self.channels, self.image_size, self.image_size)
class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores images as real or fake.

    Flattens each image, then runs an MLP
    channels*image_size**2 -> 1024 -> 512 -> 256 -> 1 with LeakyReLU(0.2)
    activations and a final Sigmoid, giving one probability per image.

    Args:
        image_size: Height and width of the square input image (default 28).
        channels: Number of input image channels (default 1).
    """

    def __init__(self, image_size=28, channels=1):
        super().__init__()
        self.in_features = channels * image_size * image_size
        self.main = nn.Sequential(
            nn.Linear(self.in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Image batch; expected shape (N, channels, image_size, image_size).

        Returns:
            Tensor of shape (N, 1) with values in [0, 1].
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # strided batches, are accepted instead of raising a RuntimeError.
        return self.main(input.reshape(-1, self.in_features))
# Initialize Models, Optimizers, and Loss Function
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build G before D, as in the original order, so RNG-dependent weight init is unchanged.
netG, netD = Generator().to(device), Discriminator().to(device)
# Plain BCE on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()
# Both networks get identical Adam settings (lr 2e-4, beta1 0.5).
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (netD, netG)
)
# Train the GAN
num_epochs = 50
# One latent batch reused for every snapshot, so images are comparable across epochs.
fixed_noise = torch.randn(64, 100, device=device)
for epoch in range(num_epochs):
    for i, (data, _) in enumerate(dataloader):
        # --- Discriminator step: BCE with real -> 1, fake -> 0 ---
        # zero_grad also clears the D gradients accumulated by the previous G step.
        netD.zero_grad()
        real_data = data.to(device)
        b_size = real_data.size(0)
        label = torch.full((b_size,), 1., dtype=torch.float, device=device)
        output = netD(real_data).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        noise = torch.randn(b_size, 100, device=device)
        fake_data = netG(noise)
        label.fill_(0.)
        # detach(): the D update must not backpropagate into G.
        output = netD(fake_data.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()
        # --- Generator step: label fakes as real (non-saturating G loss) ---
        netG.zero_grad()
        label.fill_(1.)
        output = netD(fake_data).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()
    # NOTE(review): assumed to run once per epoch (the original indentation
    # was lost). The losses reported are those of the epoch's last batch.
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD_real.item()+errD_fake.item()} Loss_G: {errG.item()}')
    if epoch % 10 == 0:
        with torch.no_grad():
            fake_images = netG(fixed_noise).cpu()
        plt.figure(figsize=(10, 10))
        plt.axis("off")
        # 1-based, like the print above: this snapshot is taken after epoch
        # number epoch+1 has finished training.
        plt.title(f"Generated Images at Epoch {epoch+1}")
        plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
        plt.show()
# Generate and Visualize Synthetic Images
# Draw 64 fresh latent vectors (not fixed_noise), then show them as one 8x8 grid.
with torch.no_grad():
    noise = torch.randn((64, 100), device=device)
    fake_images = netG(noise).cpu().detach()
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.title("Generated Images")
# make_grid returns (C, H, W); imshow wants (H, W, C).
plt.imshow(
    np.transpose(
        vutils.make_grid(fake_images, normalize=True, padding=2),
        axes=(1, 2, 0),
    )
)
plt.show()