eaglelandsonce commited on
Commit
cbc79c2
·
verified ·
1 Parent(s): fb402a1

Rename pages/23_gan.py to pages/23_Gan.py

Browse files
Files changed (2) hide show
  1. pages/23_Gan.py +84 -0
  2. pages/23_gan.py +0 -118
pages/23_Gan.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.transforms as transforms
5
+ import torchvision.utils as vutils
6
+ import streamlit as st
7
+
8
+ # Define the Generator
9
+ class Generator(nn.Module):
10
+ def __init__(self):
11
+ super(Generator, self).__init__()
12
+ self.main = nn.Sequential(
13
+ nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
14
+ nn.BatchNorm2d(256),
15
+ nn.ReLU(True),
16
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
17
+ nn.BatchNorm2d(128),
18
+ nn.ReLU(True),
19
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
20
+ nn.BatchNorm2d(64),
21
+ nn.ReLU(True),
22
+ nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
23
+ nn.Tanh()
24
+ )
25
+
26
+ def forward(self, input):
27
+ return self.main(input)
28
+
29
+ # Define the Discriminator
30
+ class Discriminator(nn.Module):
31
+ def __init__(self):
32
+ super(Discriminator, self).__init__()
33
+ self.main = nn.Sequential(
34
+ nn.Conv2d(1, 64, 4, 2, 1, bias=False),
35
+ nn.LeakyReLU(0.2, inplace=True),
36
+ nn.Conv2d(64, 128, 4, 2, 1, bias=False),
37
+ nn.BatchNorm2d(128),
38
+ nn.LeakyReLU(0.2, inplace=True),
39
+ nn.Conv2d(128, 256, 4, 2, 1, bias=False),
40
+ nn.BatchNorm2d(256),
41
+ nn.LeakyReLU(0.2, inplace=True),
42
+ nn.Conv2d(256, 1, 4, 1, 0, bias=False),
43
+ nn.Sigmoid()
44
+ )
45
+
46
+ def forward(self, input):
47
+ return self.main(input)
48
+
49
+ # Initialize the models
50
+ netG = Generator()
51
+ netD = Discriminator()
52
+
53
+ # Loss function
54
+ criterion = nn.BCELoss()
55
+
56
+ # Optimizers
57
+ optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
58
+ optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
59
+
60
+ # Device
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ netG.to(device)
63
+ netD.to(device)
64
+ criterion.to(device)
65
+
66
+ # Function to generate and save images
67
+ def generate_images(num_images, noise_dim):
68
+ netG.eval()
69
+ noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
70
+ fake_images = netG(noise)
71
+ return fake_images
72
+
73
+ # Streamlit interface
74
+ st.title("Simple GAN with Streamlit")
75
+ st.write("Generate images using a simple GAN")
76
+
77
+ num_images = st.slider("Number of images to generate", min_value=1, max_value=64, value=8)
78
+ noise_dim = 100
79
+
80
+ if st.button("Generate Images"):
81
+ with st.spinner("Generating images..."):
82
+ fake_images = generate_images(num_images, noise_dim)
83
+ grid = vutils.make_grid(fake_images.cpu(), padding=2, normalize=True)
84
+ st.image(grid.permute(1, 2, 0).numpy(), caption="Generated Images")
pages/23_gan.py DELETED
@@ -1,118 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import torchvision.transforms as transforms
5
- import torchvision.datasets as datasets
6
- import torchvision.utils as vutils
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
-
10
- # Load and Preprocess the MNIST Dataset
11
- transform = transforms.Compose([
12
- transforms.ToTensor(),
13
- transforms.Normalize((0.5,), (0.5,))
14
- ])
15
-
16
- mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
17
- dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=128, shuffle=True)
18
-
19
- # Define the Generator and Discriminator Networks
20
- class Generator(nn.Module):
21
- def __init__(self):
22
- super(Generator, self).__init__()
23
- self.main = nn.Sequential(
24
- nn.Linear(100, 256),
25
- nn.ReLU(True),
26
- nn.Linear(256, 512),
27
- nn.ReLU(True),
28
- nn.Linear(512, 1024),
29
- nn.ReLU(True),
30
- nn.Linear(1024, 28*28),
31
- nn.Tanh()
32
- )
33
-
34
- def forward(self, input):
35
- return self.main(input).view(-1, 1, 28, 28)
36
-
37
- class Discriminator(nn.Module):
38
- def __init__(self):
39
- super(Discriminator, self).__init__()
40
- self.main = nn.Sequential(
41
- nn.Linear(28*28, 1024),
42
- nn.LeakyReLU(0.2, inplace=True),
43
- nn.Linear(1024, 512),
44
- nn.LeakyReLU(0.2, inplace=True),
45
- nn.Linear(512, 256),
46
- nn.LeakyReLU(0.2, inplace=True),
47
- nn.Linear(256, 1),
48
- nn.Sigmoid()
49
- )
50
-
51
- def forward(self, input):
52
- return self.main(input.view(-1, 28*28))
53
-
54
- # Initialize Models, Optimizers, and Loss Function
55
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
-
57
- netG = Generator().to(device)
58
- netD = Discriminator().to(device)
59
-
60
- criterion = nn.BCELoss()
61
-
62
- optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
63
- optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
64
-
65
- # Train the GAN
66
- num_epochs = 50
67
- fixed_noise = torch.randn(64, 100, device=device)
68
-
69
- for epoch in range(num_epochs):
70
- for i, (data, _) in enumerate(dataloader):
71
- # Train Discriminator
72
- netD.zero_grad()
73
- real_data = data.to(device)
74
- b_size = real_data.size(0)
75
- label = torch.full((b_size,), 1., dtype=torch.float, device=device)
76
-
77
- output = netD(real_data).view(-1)
78
- errD_real = criterion(output, label)
79
- errD_real.backward()
80
-
81
- noise = torch.randn(b_size, 100, device=device)
82
- fake_data = netG(noise)
83
- label.fill_(0.)
84
-
85
- output = netD(fake_data.detach()).view(-1)
86
- errD_fake = criterion(output, label)
87
- errD_fake.backward()
88
- optimizerD.step()
89
-
90
- # Train Generator
91
- netG.zero_grad()
92
- label.fill_(1.)
93
- output = netD(fake_data).view(-1)
94
- errG = criterion(output, label)
95
- errG.backward()
96
- optimizerG.step()
97
-
98
- print(f'Epoch [{epoch+1}/{num_epochs}] Loss_D: {errD_real.item()+errD_fake.item()} Loss_G: {errG.item()}')
99
-
100
- if epoch % 10 == 0:
101
- with torch.no_grad():
102
- fake_images = netG(fixed_noise).detach().cpu()
103
- plt.figure(figsize=(10, 10))
104
- plt.axis("off")
105
- plt.title(f"Generated Images at Epoch {epoch}")
106
- plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
107
- plt.show()
108
-
109
- # Generate and Visualize Synthetic Images
110
- with torch.no_grad():
111
- noise = torch.randn(64, 100, device=device)
112
- fake_images = netG(noise).detach().cpu()
113
-
114
- plt.figure(figsize=(10, 10))
115
- plt.axis("off")
116
- plt.title("Generated Images")
117
- plt.imshow(np.transpose(vutils.make_grid(fake_images, padding=2, normalize=True), (1, 2, 0)))
118
- plt.show()