"""Streamlit app that trains a small fully connected GAN on MNIST.

Streamlit re-executes this whole script on every widget interaction, so the
dataset, models and optimizers below are rebuilt on each rerun; training only
happens in the run triggered by the "Train GAN" button.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st


# Define the Generator
class Generator(nn.Module):
    """Map latent vectors to flattened images.

    The final ``Tanh`` keeps outputs in [-1, 1], the same range the real
    images are normalized to by ``transform`` below.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a ``(batch, output_dim)`` tensor for ``x`` of shape ``(batch, input_dim)``."""
        return self.model(x)


# Define the Discriminator
class Discriminator(nn.Module):
    """Score flattened images with a probability of being real.

    The final ``Sigmoid`` produces values in (0, 1), which is what
    ``nn.BCELoss`` (``criterion`` below) expects.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real-probabilities for ``x`` of shape ``(batch, input_dim)``."""
        return self.model(x)


# Hyperparameters
latent_dim = 100
image_dim = 28 * 28  # MNIST images are 28x28 pixels
lr = 0.0002
batch_size = 64

# Use a GPU when one is available; everything falls back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare the data: ToTensor gives [0, 1], Normalize(0.5, 0.5) maps it to [-1, 1]
# so real images share the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the models (moved to the device before the optimizers see their parameters)
generator = Generator(latent_dim, image_dim).to(device)
discriminator = Discriminator(image_dim).to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Loss function
criterion = nn.BCELoss()


def _train_one_epoch():
    """Run one pass over ``dataloader`` and return the mean (D loss, G loss) over its batches."""
    # Accumulate on-device and convert once per epoch to avoid a host sync every batch.
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for imgs, _ in dataloader:
        imgs = imgs.to(device)
        n = imgs.size(0)  # the last batch can be smaller than batch_size
        real_imgs = imgs.view(n, -1)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)
        z = torch.randn(n, latent_dim, device=device)
        fake_imgs = generator(z)

        # Train Discriminator: fakes are detached so this backward pass stops at D.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: push D to label the same fakes as real. This backward
        # also writes grads into D's parameters; they are cleared by
        # optimizer_D.zero_grad() before D's next update.
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        d_loss_sum += d_loss.detach()
        g_loss_sum += g_loss.detach()
        num_batches += 1

    num_batches = max(num_batches, 1)  # guard against an empty loader
    return d_loss_sum.item() / num_batches, g_loss_sum.item() / num_batches


def _show_samples():
    """Draw 16 generator samples as a 4x4 grid on the Streamlit page."""
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        z = torch.randn(16, latent_dim, device=device)
        generated_imgs = generator(z).view(-1, 28, 28).cpu().numpy()

    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for img, ax in zip(generated_imgs, axes.flatten()):
        ax.imshow(img, cmap="gray")
        ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps every open figure alive; close it so reruns don't accumulate figures.
    plt.close(fig)


# Streamlit interface
st.title("GAN with PyTorch and Hugging Face")
st.write("Training a GAN to generate MNIST digits")

# Slider for epochs
epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50)
train_gan = st.button("Train GAN")

if train_gan:
    # Training loop; reported losses are epoch means rather than last-batch values.
    for epoch in range(epochs):
        d_loss_avg, g_loss_avg = _train_one_epoch()
        st.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss_avg:.4f} | G Loss: {g_loss_avg:.4f}")
    st.write("Training completed")

    # Generate and display images
    _show_samples()
else:
    st.write("Use the slider to select the number of epochs and click the button to start training the GAN")