import gradio as gr
import torch
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from cgan import Generator
|
|
|
# Generator hyper-parameters. `Generator()` below is built without arguments,
# so these values are not passed to it.
# NOTE(review): presumably they must match the values used in `cgan` when the
# checkpoint was trained -- confirm.
latent_dim = 100  # length of the noise vector fed to the generator
n_classes = 10  # number of selectable class labels (0 .. n_classes - 1)
img_size = 32  # not referenced anywhere in the visible code
channels = 1  # not referenced anywhere in the visible code

# Build the generator and restore its trained weights on the CPU.
model = Generator()
# SECURITY: torch.load unpickles the file, so "generator1.pth" must come from
# a trusted source.
model.load_state_dict(torch.load("generator1.pth", map_location=torch.device('cpu')))
# Inference mode: disables training-only behaviour such as dropout and
# batch-norm statistic updates.
model.eval()
|
|
|
|
|
def generate_image(class_idx):
    """Generate one image for the requested class.

    A fresh random noise vector is drawn on every call, so repeated calls
    with the same class produce different images.

    Args:
        class_idx: Class label to condition on. Anything accepted by
            ``int()`` works; the "Class" dropdown supplies digit strings.

    Returns:
        A PIL image built from the generator output.
    """
    # Gradients are never needed for inference.
    with torch.no_grad():
        noise = torch.randn(1, latent_dim)
        label = torch.tensor([int(class_idx)])
        # Drop the leading batch dimension (batch size is 1).
        # NOTE(review): assumes the generator returns (1, C, H, W) -- confirm
        # against cgan.Generator.
        gen_img = model(noise, label).squeeze(0)
        # make_grid is used on a single image for its normalize=True
        # rescaling into the range to_pil_image expects.
        return to_pil_image(make_grid(gen_img, normalize=True))
|
|
|
|
|
|
|
# UI components.
# NOTE(review): `gr.inputs` / `gr.outputs` are Gradio's legacy component
# namespaces; newer Gradio releases expose `gr.Slider`, `gr.Dropdown` and
# `gr.Image` instead -- confirm against the pinned Gradio version.

# Defined but not passed to the Interface below, and generate_image takes no
# noise argument, so this slider currently has no effect.
noise_input = gr.inputs.Slider(minimum=-1.0, maximum=1.0, default=0, step=0.1, label="Noise")
# Choices are the class indices rendered as strings; generate_image converts
# the selected value back with int().
class_input = gr.inputs.Dropdown([str(i) for i in range(n_classes)], label="Class")
output_image = gr.outputs.Image('pil')

# Build the demo and start the web server as soon as the script runs.
gr.Interface(
    fn=generate_image,
    inputs=[class_input],
    outputs=output_image,
    title="MNIST Generator",
    description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
    theme="default",
    layout="vertical",
    live=True,  # regenerate whenever the selected class changes
).launch(debug=True)