eaglelandsonce commited on
Commit
ec2b0f4
·
verified ·
1 Parent(s): d23f696

Create 26_GANS.py

Browse files
Files changed (1) hide show
  1. pages/26_GANS.py +113 -0
pages/26_GANS.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision
6
+ import torchvision.transforms as transforms
7
+ from torchvision.utils import make_grid
8
+ import matplotlib.pyplot as plt
9
+
10
+ # Set device
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ # Hyperparameters
14
+ z_dim = 64
15
+ image_dim = 28 * 28
16
+ batch_size = 32
17
+ lr = 3e-4
18
+
19
+ # Load Data
20
+ transform = transforms.Compose([
21
+ transforms.ToTensor(),
22
+ transforms.Normalize((0.5,), (0.5,))
23
+ ])
24
+
25
+ dataset = torchvision.datasets.MNIST(root='dataset/', transform=transform, download=True)
26
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
27
+
28
+ # Generator
29
+ class Generator(nn.Module):
30
+ def __init__(self, z_dim, img_dim):
31
+ super().__init__()
32
+ self.gen = nn.Sequential(
33
+ nn.Linear(z_dim, 256),
34
+ nn.ReLU(),
35
+ nn.Linear(256, 512),
36
+ nn.ReLU(),
37
+ nn.Linear(512, 1024),
38
+ nn.ReLU(),
39
+ nn.Linear(1024, img_dim),
40
+ nn.Tanh()
41
+ )
42
+
43
+ def forward(self, x):
44
+ return self.gen(x)
45
+
46
+ # Discriminator
47
+ class Discriminator(nn.Module):
48
+ def __init__(self, img_dim):
49
+ super().__init__()
50
+ self.disc = nn.Sequential(
51
+ nn.Linear(img_dim, 1024),
52
+ nn.ReLU(),
53
+ nn.Linear(1024, 512),
54
+ nn.ReLU(),
55
+ nn.Linear(512, 256),
56
+ nn.ReLU(),
57
+ nn.Linear(256, 1),
58
+ nn.Sigmoid(),
59
+ )
60
+
61
+ def forward(self, x):
62
+ return self.disc(x)
63
+
64
+ # Initialize generator and discriminator
65
+ gen = Generator(z_dim, image_dim).to(device)
66
+ disc = Discriminator(image_dim).to(device)
67
+
68
+ # Optimizers
69
+ opt_gen = optim.Adam(gen.parameters(), lr=lr)
70
+ opt_disc = optim.Adam(disc.parameters(), lr=lr)
71
+
72
+ # Loss function
73
+ criterion = nn.BCELoss()
74
+
75
+ # Function to train the model
76
+ def train_gan(epochs):
77
+ for epoch in range(epochs):
78
+ for batch_idx, (real, _) in enumerate(dataloader):
79
+ real = real.view(-1, 784).to(device)
80
+ batch_size = real.shape[0]
81
+
82
+ # Train Discriminator
83
+ noise = torch.randn(batch_size, z_dim).to(device)
84
+ fake = gen(noise)
85
+ disc_real = disc(real).view(-1)
86
+ lossD_real = criterion(disc_real, torch.ones_like(disc_real))
87
+ disc_fake = disc(fake).view(-1)
88
+ lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
89
+ lossD = (lossD_real + lossD_fake) / 2
90
+ disc.zero_grad()
91
+ lossD.backward(retain_graph=True)
92
+ opt_disc.step()
93
+
94
+ # Train Generator
95
+ output = disc(fake).view(-1)
96
+ lossG = criterion(output, torch.ones_like(output))
97
+ gen.zero_grad()
98
+ lossG.backward()
99
+ opt_gen.step()
100
+
101
+ st.write(f"Epoch [{epoch+1}/{epochs}] Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")
102
+
103
+ return fake
104
+
105
+ # Streamlit interface
106
+ st.title("Simple GAN with Epoch Slider")
107
+ epochs = st.slider("Number of Epochs", 1, 100, 1)
108
+ if st.button("Train GAN"):
109
+ fake_images = train_gan(epochs)
110
+ fake_images = fake_images.view(-1, 1, 28, 28)
111
+ fake_images = make_grid(fake_images, nrow=8, normalize=True)
112
+ plt.imshow(fake_images.permute(1, 2, 0).cpu().detach().numpy(), cmap='gray')
113
+ st.pyplot(plt.gcf())