# Spaces: Running
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
import streamlit as st | |
# Define the Generator
class Generator(nn.Module):
    """DCGAN-style generator that upsamples latent noise into images.

    With the default arguments this is exactly the original fixed network:
    a ``(N, 100, 1, 1)`` noise batch is upsampled by four transposed
    convolutions to ``(N, 1, 32, 32)`` images squashed into ``[-1, 1]`` by
    the final ``Tanh``. Layer order (and therefore ``state_dict`` keys and
    shapes) is unchanged for the defaults.

    Args:
        nz: Number of latent channels expected on the input (default 100).
        ngf: Base feature width; hidden layers use ``4*ngf``, ``2*ngf`` and
            ``ngf`` channels (default 64).
        nc: Number of channels in the generated image (default 1).
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 4*ngf, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, 2*ngf, 8, 8)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 32, 32), values in [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map a ``(N, nz, 1, 1)`` noise batch to ``(N, nc, 32, 32)`` images."""
        return self.main(input)
# Define the Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images with a sigmoid output.

    With the default arguments this is exactly the original fixed network:
    a ``(N, 1, 32, 32)`` image batch is downsampled by four strided
    convolutions to a ``(N, 1, 1, 1)`` tensor of values in ``[0, 1]``.
    Other spatial input sizes yield a different output spatial size.
    Layer order (and therefore ``state_dict`` keys and shapes) is unchanged
    for the defaults.

    Args:
        nc: Number of channels in the input image (default 1).
        ndf: Base feature width; layers use ``ndf``, ``2*ndf`` and
            ``4*ndf`` channels (default 64).
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            # (N, nc, 32, 32) -> (N, ndf, 16, 16); this first layer has no BatchNorm
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 2*ndf, 8, 8)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 4*ndf, 4, 4)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1), squashed to [0, 1]
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score an image batch; ``(N, nc, 32, 32)`` -> ``(N, 1, 1, 1)``."""
        return self.main(input)
# Device: resolved first so the models are placed on it before their
# optimizers are constructed (the order torch.optim recommends).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the models directly on the target device.
# NOTE(review): netG starts from freshly initialized (random) weights and
# nothing visible in this file trains it or loads a checkpoint, so the
# generated images come from an untrained network — confirm whether a
# saved state_dict should be loaded here.
netG = Generator().to(device)
netD = Discriminator().to(device)

# Loss function (BCELoss has no parameters; .to(device) is kept for symmetry)
criterion = nn.BCELoss().to(device)

# Optimizers, built after the models were moved to `device`.
# NOTE(review): netD, criterion and both optimizers are not used by the
# visible Streamlit code; presumably they are meant for a training loop.
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Function to generate images from random latent noise (nothing is saved)
def generate_images(num_images, noise_dim):
    """Sample a batch of images from the module-level generator ``netG``.

    Args:
        num_images: Number of images (batch size) to draw.
        noise_dim: Number of latent channels; must match ``netG``'s first
            layer input channels (100 for ``Generator()`` in this file).

    Returns:
        Tensor of generated images on ``device``. It is produced under
        ``torch.no_grad()``, so it does not require grad and can be turned
        into a NumPy array after ``.cpu()``.
    """
    # eval(): BatchNorm layers use their running statistics, not batch statistics.
    netG.eval()
    noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
    # Inference only. Without no_grad the output carries requires_grad=True,
    # which makes a later `.numpy()` on it (or on make_grid's result) raise.
    with torch.no_grad():
        fake_images = netG(noise)
    return fake_images
# Streamlit interface
st.title("Simple GAN with Streamlit")
st.write("Generate images using a simple GAN")

num_images = st.slider("Number of images to generate", min_value=1, max_value=64, value=8)
noise_dim = 100  # must equal netG's latent (input) channel count

if st.button("Generate Images"):
    with st.spinner("Generating images..."):
        fake_images = generate_images(num_images, noise_dim)
        # detach(): Tensor.numpy() refuses tensors that require grad, and the
        # generator output is attached to the autograd graph unless it was
        # produced under torch.no_grad().
        grid = vutils.make_grid(fake_images.detach().cpu(), padding=2, normalize=True)
        # make_grid returns (C, H, W); permute to channels-last (H, W, C) for st.image.
        st.image(grid.permute(1, 2, 0).numpy(), caption="Generated Images")