import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
import streamlit as st

# Length of the latent noise vector consumed by the generator. Shared by the
# Generator definition and the sampling UI so the two cannot drift apart.
NOISE_DIM = 100


# Define the Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image.

    The input is expected as ``(N, nz, 1, 1)``. The four transposed
    convolutions upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32, so the output is
    ``(N, nc, 32, 32)`` with values in ``[-1, 1]`` (final ``Tanh``).

    Args:
        nz: Latent vector length (input channels of the first layer).
        ngf: Base feature-map width; hidden layers use ``4*ngf``, ``2*ngf``
            and ``ngf`` channels.
        nc: Number of output image channels.
    """

    def __init__(self, nz=NOISE_DIM, ngf=64, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the latent batch through the generator network."""
        return self.main(input)


# Define the Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real/fake.

    For a ``(N, nc, 32, 32)`` input the strided convolutions downsample
    32 -> 16 -> 8 -> 4 -> 1, giving an ``(N, 1, 1, 1)`` output squashed to
    ``(0, 1)`` by ``Sigmoid`` (suitable for ``nn.BCELoss``).

    Args:
        nc: Number of input image channels.
        ndf: Base feature-map width; layers use ``ndf``, ``2*ndf``, ``4*ndf``.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return per-sample real/fake probabilities."""
        return self.main(input)


# Initialize the models.
# NOTE(review): nothing in this script trains the models or loads a
# checkpoint, so generated images come from randomly initialized weights.
# Streamlit also re-executes the whole script on every widget interaction,
# so fresh weights are created on each rerun.
netG = Generator()
netD = Discriminator()

# Loss function
criterion = nn.BCELoss()

# Optimizers (DCGAN-style Adam settings)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG.to(device)
netD.to(device)
criterion.to(device)


def generate_images(num_images, noise_dim=NOISE_DIM):
    """Sample ``num_images`` images from the generator.

    Args:
        num_images: Batch size of the noise tensor.
        noise_dim: Latent vector length; must match the generator's ``nz``.

    Returns:
        A detached ``(num_images, 1, 32, 32)`` tensor on ``device`` with
        values in ``[-1, 1]``.
    """
    was_training = netG.training
    # Eval mode makes BatchNorm use running statistics instead of batch stats.
    netG.eval()
    try:
        # no_grad: inference needs no autograd graph, and a grad-tracking
        # output would make the later ``.numpy()`` call raise RuntimeError.
        with torch.no_grad():
            noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
            return netG(noise)
    finally:
        # Restore the caller's mode so sampling has no lasting side effect.
        netG.train(was_training)


# Streamlit interface
st.title("Simple GAN with Streamlit")
st.write("Generate images using a simple GAN")
num_images = st.slider("Number of images to generate", min_value=1, max_value=64, value=8)
noise_dim = NOISE_DIM

if st.button("Generate Images"):
    with st.spinner("Generating images..."):
        fake_images = generate_images(num_images, noise_dim)
        # normalize=True rescales the Tanh output into [0, 1] for display.
        grid = vutils.make_grid(fake_images.cpu(), padding=2, normalize=True)
        # make_grid returns (C, H, W); st.image expects channels-last.
        st.image(grid.permute(1, 2, 0).numpy(), caption="Generated Images")