eaglelandsonce commited on
Commit
fb402a1
·
verified ·
1 Parent(s): 5a1cec5

Create 23_gan.py

Browse files
Files changed (1) hide show
  1. pages/23_gan.py +118 -0
pages/23_gan.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.transforms as transforms
5
+ import torchvision.datasets as datasets
6
+ import torchvision.utils as vutils
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ # Load and Preprocess the MNIST Dataset
11
+ transform = transforms.Compose([
12
+ transforms.ToTensor(),
13
+ transforms.Normalize((0.5,), (0.5,))
14
+ ])
15
+
16
+ mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
17
+ dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=128, shuffle=True)
18
+
19
+ # Define the Generator and Discriminator Networks
20
+ class Generator(nn.Module):
21
+ def __init__(self):
22
+ super(Generator, self).__init__()
23
+ self.main = nn.Sequential(
24
+ nn.Linear(100, 256),
25
+ nn.ReLU(True),
26
+ nn.Linear(256, 512),
27
+ nn.ReLU(True),
28
+ nn.Linear(512, 1024),
29
+ nn.ReLU(True),
30
+ nn.Linear(1024, 28*28),
31
+ nn.Tanh()
32
+ )
33
+
34
+ def forward(self, input):
35
+ return self.main(input).view(-1, 1, 28, 28)
36
+
37
+ class Discriminator(nn.Module):
38
+ def __init__(self):
39
+ super(Discriminator, self).__init__()
40
+ self.main = nn.Sequential(
41
+ nn.Linear(28*28, 1024),
42
+ nn.LeakyReLU(0.2, inplace=True),
43
+ nn.Linear(1024, 512),
44
+ nn.LeakyReLU(0.2, inplace=True),
45
+ nn.Linear(512, 256),
46
+ nn.LeakyReLU(0.2, inplace=True),
47
+ nn.Linear(256, 1),
48
+ nn.Sigmoid()
49
+ )
50
+
51
+ def forward(self, input):
52
+ return self.main(input.view(-1, 28*28))
53
+
54
+ # Initialize Models, Optimizers, and Loss Function
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+
57
+ netG = Generator().to(device)
58
+ netD = Discriminator().to(device)
59
+
60
+ criterion = nn.BCELoss()
61
+
62
+ optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
63
+ optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
64
+
65
+ # Train the GAN
66
+ num_epochs = 50
67
+ fixed_noise = torch.randn(64, 100, device=device)
68
+
69
+ for epoch in range(num_epochs):
70
+ for i, (data, _) in enumerate(dataloader):
71
+ # Train Discriminator
72
+ netD.zero_grad()
73
+ real_data = data.to(device)
74
+ b_size = real_data.size(0)
75
+ label = torch.full((b_size,), 1., dtype=torch.float, device=device)
76
+
77
+ output = netD(real_data).view(-1)
78
+ errD_real = criterion(output, label)
79
+ errD_real.backward()
80
+
81
+ noise = torch.randn(b_size, 100, device=device)
82
+ fake_data = netG(noise)
83
+ label.fill_(0.)
84
+
85
+ output = netD(fake_data.detach()).view(-1)
86
+ errD_fake = criterion(output, label)
87
+ errD_fake.backward()
88
+ optimizerD.step()
89
+
90
+ # Train Generator
91
+ netG.zero_grad()
92
+ label.fill_(1.)
93
+ output = netD(fake_data).view(-1)
94
+ errG = criterion(output, label)
95
+ errG.backward()
96
+ optimizerG.step()
97
+
98
+ print(f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD_real.item()+errD_fake.item()} Loss_G: {errG.item()}')
99
+
100
+ if epoch % 10 == 0:
101
+ with torch.no_grad():
102
+ fake_images = netG(fixed_noise).detach().cpu()
103
+ plt.figure(figsize=(10, 10))
104
+ plt.axis("off")
105
+ plt.title(f"Generated Images at Epoch {epoch}")
106
+ plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
107
+ plt.show()
108
+
109
+ # Generate and Visualize Synthetic Images
110
+ with torch.no_grad():
111
+ noise = torch.randn(64, 100, device=device)
112
+ fake_images = netG(noise).detach().cpu()
113
+
114
+ plt.figure(figsize=(10, 10))
115
+ plt.axis("off")
116
+ plt.title("Generated Images")
117
+ plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
118
+ plt.show()