"""Gradio demo that serves a pretrained conditional GAN generator for digit images.

Running this module loads generator weights from ``generator1.pth`` onto the
CPU and launches a Gradio interface with a single class dropdown; every
request samples fresh random noise and returns one generated image.
"""

import gradio as gr
import torch
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from cgan import Generator

# Generator configuration.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with -- confirm against the training script.
latent_dim = 100  # length of the random noise vector fed to the generator
n_classes = 10    # number of selectable class labels (dropdown offers 0..9)
img_size = 32     # not referenced anywhere in this script
channels = 1      # not referenced anywhere in this script

# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source (newer torch versions offer weights_only=True -- confirm the
# pinned torch version before enabling it).
model = Generator()
model.load_state_dict(torch.load("generator1.pth", map_location=torch.device('cpu')))
model.eval()  # inference mode for layers that behave differently in training


def generate_image(class_idx):
    """Generate one image for the requested class from fresh random noise.

    Args:
        class_idx: Class label to condition on. Anything ``int()`` accepts;
            the Gradio dropdown passes the digit as a string.

    Returns:
        A PIL image built from the generator output, with pixel values
        rescaled by ``make_grid(normalize=True)``.

    Raises:
        ValueError: If ``class_idx`` is not an integer in ``[0, n_classes)``.
        TypeError: If ``class_idx`` cannot be converted by ``int()`` at all
            (for example ``None``).
    """
    label_idx = int(class_idx)
    # Fail fast with a readable message instead of an opaque error from
    # inside the model when the label is out of range.
    if not 0 <= label_idx < n_classes:
        raise ValueError(
            f"class_idx must be in [0, {n_classes - 1}], got {class_idx!r}"
        )

    with torch.no_grad():
        # Batch of one noise vector of length latent_dim.
        noise = torch.randn(1, latent_dim)
        label = torch.tensor([label_idx])
        # Drop the leading batch dimension of the generator output.
        # NOTE(review): assumes the output is (1, C, H, W) -- confirm in cgan.
        gen_img = model(noise, label).squeeze(0)
    return to_pil_image(make_grid(gen_img, normalize=True))


# Create Gradio Interface
# NOTE(review): gr.inputs / gr.outputs and the theme/layout arguments belong to
# the legacy Gradio API -- confirm the pinned Gradio version before upgrading.
class_input = gr.inputs.Dropdown([str(i) for i in range(n_classes)], label="Class")
output_image = gr.outputs.Image('pil')

demo = gr.Interface(
    fn=generate_image,
    inputs=[class_input],
    outputs=output_image,
    title="MNIST Generator",
    description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
    theme="default",
    layout="vertical",
    live=True,
)
demo.launch(debug=True)