# pytorch / pages /23_Gan.py
# eaglelandsonce's picture
# Rename pages/23_gan.py to pages/23_Gan.py
# cbc79c2 verified
# raw
# history blame
# 2.65 kB
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
import streamlit as st
# Define the Generator
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Given latent noise of shape ``(N, 100, 1, 1)`` the four transposed
    convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 32, so the output
    is ``(N, 1, 32, 32)`` with values squashed into [-1, 1] by the final Tanh.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each hidden stage;
        # every stage uses a bias-free 4x4 transposed convolution.
        hidden_stages = (
            (100, 256, 1, 0),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        )
        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Output stage: collapse to a single channel and bound values to [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        # The attribute name and layer order are what the state_dict keys
        # (main.0 ... main.10) are built from, so both are kept stable.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run a noise batch through the upsampling stack."""
        return self.main(input)
# Define the Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    For single-channel 32x32 inputs (the size ``Generator`` produces from
    1x1 noise) the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1, so the output
    is ``(N, 1, 1, 1)`` with each score passed through a final Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Entry stage: no BatchNorm on the first convolution.
        layers = [
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling stages that halve the spatial size and double channels.
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Scoring stage: a 4x4 valid convolution down to one value per image.
        layers.extend([
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        # Same attribute name and layer order as before, so state_dict keys match.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a batch of images."""
        return self.main(input)
# Initialize the models, loss function, and optimizers.
# The device is chosen first so each model is moved there *before* its
# optimizer is constructed: the torch.optim docs advise moving a model before
# building optimizers for it, since the move may yield new parameter objects
# that an earlier-built optimizer would not track.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator is constructed before Discriminator, as before, so random weight
# initialization happens in the same order.
netG = Generator().to(device)
netD = Discriminator().to(device)

# Loss function: BCELoss takes probabilities, matching the Sigmoid that ends
# the Discriminator.
criterion = nn.BCELoss().to(device)

# Optimizers: Adam with lr=2e-4 and betas=(0.5, 0.999) for both networks.
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
# NOTE(review): criterion, optimizerD and optimizerG are not used anywhere in
# the visible code (there is no training loop or checkpoint load), so images
# are sampled from the generator's initial random weights — confirm intent.
# Function to generate images (returns a tensor; nothing is written to disk)
def generate_images(num_images, noise_dim, model=None, target_device=None):
    """Sample a batch of images from a generator.

    Args:
        num_images: Number of images (batch size) to sample.
        noise_dim: Channel count of the latent noise; must equal the
            generator's input channels (100 for ``Generator``).
        model: Generator to sample from. Defaults to the module-level ``netG``.
        target_device: Device on which the noise is created. Defaults to the
            module-level ``device``; it should match where ``model`` lives.

    Returns:
        The generator's output for noise of shape
        ``(num_images, noise_dim, 1, 1)``, produced without autograd tracking.
    """
    if model is None:
        model = netG
    if target_device is None:
        target_device = device
    # Switch to inference mode (affects layers such as BatchNorm).
    model.eval()
    noise = torch.randn(num_images, noise_dim, 1, 1, device=target_device)
    # Pure inference: no_grad skips building an autograd graph (saving memory)
    # and returns a tensor that can be converted to NumPy directly.
    with torch.no_grad():
        return model(noise)
# Streamlit interface: page header, batch-size slider, and a button that
# samples a batch and shows it as one tiled image.
st.title("Simple GAN with Streamlit")
st.write("Generate images using a simple GAN")

num_images = st.slider(
    "Number of images to generate",
    min_value=1,
    max_value=64,
    value=8,
)
# Latent size; must equal the in_channels of Generator's first ConvTranspose2d.
noise_dim = 100

if st.button("Generate Images"):
    with st.spinner("Generating images..."):
        fake_images = generate_images(num_images, noise_dim)
        # Tile the batch into a single image, min-max normalized for display.
        grid = vutils.make_grid(fake_images.cpu(), padding=2, normalize=True)
        # make_grid yields (C, H, W); reorder to (H, W, C) for st.image.
        grid_hwc = grid.permute(1, 2, 0)
        st.image(grid_hwc.numpy(), caption="Generated Images")