# cGAN / app.py
# Last commit by Jihene: "Update app.py" (c93ed73)
from cgan import Generator
import gradio as gr
import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
# Hyper-parameters. Presumably these must match the values the checkpoint
# was trained with — TODO confirm against cgan.py.
latent_dim = 100  # length of the random noise vector fed to the generator
n_classes = 10  # number of class labels offered in the "Class" dropdown
img_size = 32  # not referenced by the code shown here
channels = 1  # not referenced by the code shown here

# Build the generator, load its trained weights onto the CPU and switch it to
# evaluation mode for inference.
# ``weights_only=True`` restricts unpickling to tensors and plain containers
# (all a state_dict needs), so a tampered checkpoint cannot run arbitrary code
# on load; it is also the default from PyTorch 2.6 on, so this pins behavior.
model = Generator()
model.load_state_dict(
    torch.load(
        "generator1.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()
def generate_image(class_idx):
    """Sample one image from the conditional generator for the given class.

    Args:
        class_idx: Class label chosen in the UI. The dropdown supplies it as a
            string such as ``"3"``; any value accepted by ``int()`` works.
            ``None`` (no selection) produces no image.

    Returns:
        A ``PIL.Image`` of a freshly sampled image for that class, or ``None``
        when no class is selected.
    """
    # The dropdown is created without a default value, so guard against
    # receiving None instead of letting int(None) raise TypeError.
    if class_idx is None:
        return None
    with torch.no_grad():
        # A new standard-normal latent vector on every call -> a new sample.
        noise = torch.randn(1, latent_dim)
        label = torch.tensor([int(class_idx)])
        # Drop the batch dimension. Presumably the generator returns
        # (1, C, H, W) — TODO confirm against cgan.Generator.
        gen_img = model(noise, label).squeeze(0)
        # make_grid(normalize=True) min-max rescales the values to [0, 1],
        # which to_pil_image then converts into 8-bit pixels.
        return to_pil_image(make_grid(gen_img, normalize=True))
# Create the Gradio interface.
# Uses the top-level component classes (gr.Slider / gr.Dropdown / gr.Image):
# the gr.inputs / gr.outputs namespaces, the Slider ``default=`` argument and
# Interface's ``layout=`` argument were deprecated in Gradio 3 and removed in
# Gradio 4.
# NOTE(review): noise_input is not passed to the Interface below and
# generate_image takes no noise argument, so this slider is never shown.
noise_input = gr.Slider(minimum=-1.0, maximum=1.0, value=0, step=0.1, label="Noise")
class_input = gr.Dropdown([str(i) for i in range(n_classes)], label="Class")
output_image = gr.Image(type="pil")
demo = gr.Interface(
    fn=generate_image,
    inputs=[class_input],
    outputs=output_image,
    title="MNIST Generator",
    description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
    theme="default",
    # Re-run generate_image whenever the selected class changes.
    live=True,
)
demo.launch(debug=True)