eaglelandsonce commited on
Commit
1b6b756
·
verified ·
1 Parent(s): 27eb8e5

Rename pages/23_Gan.py to pages/23_GANs.py

Browse files
Files changed (2) hide show
  1. pages/23_GANs.py +114 -0
  2. pages/23_Gan.py +0 -84
pages/23_GANs.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.transforms as transforms
5
+ import torchvision.datasets as datasets
6
+ from torch.utils.data import DataLoader
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import streamlit as st
10
+
11
+ # Define the Generator
12
+ class Generator(nn.Module):
13
+ def __init__(self, input_dim, output_dim):
14
+ super(Generator, self).__init__()
15
+ self.model = nn.Sequential(
16
+ nn.Linear(input_dim, 128),
17
+ nn.ReLU(),
18
+ nn.Linear(128, 256),
19
+ nn.ReLU(),
20
+ nn.Linear(256, output_dim),
21
+ nn.Tanh()
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.model(x)
26
+
27
+ # Define the Discriminator
28
+ class Discriminator(nn.Module):
29
+ def __init__(self, input_dim):
30
+ super(Discriminator, self).__init__()
31
+ self.model = nn.Sequential(
32
+ nn.Linear(input_dim, 256),
33
+ nn.LeakyReLU(0.2),
34
+ nn.Linear(256, 128),
35
+ nn.LeakyReLU(0.2),
36
+ nn.Linear(128, 1),
37
+ nn.Sigmoid()
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.model(x)
42
+
43
+ # Hyperparameters
44
+ latent_dim = 100
45
+ image_dim = 28 * 28 # MNIST images are 28x28 pixels
46
+ lr = 0.0002
47
+ batch_size = 64
48
+ epochs = 50
49
+
50
+ # Prepare the data
51
+ transform = transforms.Compose([
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5], [0.5])
54
+ ])
55
+
56
+ dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)
57
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
58
+
59
+ # Initialize the models
60
+ generator = Generator(latent_dim, image_dim)
61
+ discriminator = Discriminator(image_dim)
62
+
63
+ # Optimizers
64
+ optimizer_G = optim.Adam(generator.parameters(), lr=lr)
65
+ optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
66
+
67
+ # Loss function
68
+ criterion = nn.BCELoss()
69
+
70
+ # Streamlit interface
71
+ st.title("GAN with PyTorch and Hugging Face")
72
+ st.write("Training a GAN to generate MNIST digits")
73
+
74
+ train_gan = st.button("Train GAN")
75
+
76
+ if train_gan:
77
+ # Training loop
78
+ for epoch in range(epochs):
79
+ for i, (imgs, _) in enumerate(dataloader):
80
+ # Prepare real and fake data
81
+ real_imgs = imgs.view(imgs.size(0), -1)
82
+ real_labels = torch.ones(imgs.size(0), 1)
83
+ fake_labels = torch.zeros(imgs.size(0), 1)
84
+ z = torch.randn(imgs.size(0), latent_dim)
85
+ fake_imgs = generator(z)
86
+
87
+ # Train Discriminator
88
+ optimizer_D.zero_grad()
89
+ real_loss = criterion(discriminator(real_imgs), real_labels)
90
+ fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
91
+ d_loss = real_loss + fake_loss
92
+ d_loss.backward()
93
+ optimizer_D.step()
94
+
95
+ # Train Generator
96
+ optimizer_G.zero_grad()
97
+ g_loss = criterion(discriminator(fake_imgs), real_labels)
98
+ g_loss.backward()
99
+ optimizer_G.step()
100
+
101
+ st.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
102
+
103
+ st.write("Training completed")
104
+
105
+ # Generate and display images
106
+ z = torch.randn(16, latent_dim)
107
+ generated_imgs = generator(z).view(-1, 1, 28, 28).data
108
+ grid = np.transpose(np.array([generated_imgs[i].numpy() for i in range(16)]), (1, 2, 0))
109
+
110
+ fig, ax = plt.subplots(figsize=(8, 8))
111
+ ax.imshow(np.squeeze(grid), cmap="gray")
112
+ st.pyplot(fig)
113
+ else:
114
+ st.write("Click the button to start training the GAN")
pages/23_Gan.py DELETED
@@ -1,84 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import torchvision.transforms as transforms
5
- import torchvision.utils as vutils
6
- import streamlit as st
7
-
8
- # Define the Generator
9
- class Generator(nn.Module):
10
- def __init__(self):
11
- super(Generator, self).__init__()
12
- self.main = nn.Sequential(
13
- nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
14
- nn.BatchNorm2d(256),
15
- nn.ReLU(True),
16
- nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
17
- nn.BatchNorm2d(128),
18
- nn.ReLU(True),
19
- nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
20
- nn.BatchNorm2d(64),
21
- nn.ReLU(True),
22
- nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
23
- nn.Tanh()
24
- )
25
-
26
- def forward(self, input):
27
- return self.main(input)
28
-
29
- # Define the Discriminator
30
- class Discriminator(nn.Module):
31
- def __init__(self):
32
- super(Discriminator, self).__init__()
33
- self.main = nn.Sequential(
34
- nn.Conv2d(1, 64, 4, 2, 1, bias=False),
35
- nn.LeakyReLU(0.2, inplace=True),
36
- nn.Conv2d(64, 128, 4, 2, 1, bias=False),
37
- nn.BatchNorm2d(128),
38
- nn.LeakyReLU(0.2, inplace=True),
39
- nn.Conv2d(128, 256, 4, 2, 1, bias=False),
40
- nn.BatchNorm2d(256),
41
- nn.LeakyReLU(0.2, inplace=True),
42
- nn.Conv2d(256, 1, 4, 1, 0, bias=False),
43
- nn.Sigmoid()
44
- )
45
-
46
- def forward(self, input):
47
- return self.main(input)
48
-
49
- # Initialize the models
50
- netG = Generator()
51
- netD = Discriminator()
52
-
53
- # Loss function
54
- criterion = nn.BCELoss()
55
-
56
- # Optimizers
57
- optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
58
- optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
59
-
60
- # Device
61
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
- netG.to(device)
63
- netD.to(device)
64
- criterion.to(device)
65
-
66
- # Function to generate and save images
67
- def generate_images(num_images, noise_dim):
68
- netG.eval()
69
- noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
70
- fake_images = netG(noise)
71
- return fake_images
72
-
73
- # Streamlit interface
74
- st.title("Simple GAN with Streamlit")
75
- st.write("Generate images using a simple GAN")
76
-
77
- num_images = st.slider("Number of images to generate", min_value=1, max_value=64, value=8)
78
- noise_dim = 100
79
-
80
- if st.button("Generate Images"):
81
- with st.spinner("Generating images..."):
82
- fake_images = generate_images(num_images, noise_dim)
83
- grid = vutils.make_grid(fake_images.cpu(), padding=2, normalize=True)
84
- st.image(grid.permute(1, 2, 0).numpy(), caption="Generated Images")