Jihene commited on
Commit
7862640
·
1 Parent(s): 6c6a596

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !pip install gradio
2
+ import gradio as gr
3
+ import torch
4
+ from torchvision.utils import make_grid
5
+ from torchvision.transforms.functional import to_pil_image
6
+
7
+ latent_dim = 100
8
+ n_classes = 10
9
+ img_size = 32
10
+ channels = 1
11
+
12
+ model = Generator()
13
+ model.load_state_dict(torch.load("generator1.pth", map_location=torch.device('cpu')))
14
+ model.eval()
15
+
16
+
17
+ def generate_image(class_idx):
18
+ with torch.no_grad():
19
+
20
+ # Generate random noise vector of latent_dim size
21
+ noise = torch.randn(1, latent_dim)
22
+ label = torch.tensor([int(class_idx)])
23
+ gen_img = model(noise, label).squeeze(0)
24
+ return to_pil_image(make_grid(gen_img, normalize=True))
25
+
26
+
27
+ # Create Gradio Interface
28
+ noise_input = gr.inputs.Slider(minimum=-1.0, maximum=1.0, default=0, step=0.1, label="Noise")
29
+ class_input = gr.inputs.Dropdown([str(i) for i in range(n_classes)], label="Class")
30
+ output_image = gr.outputs.Image('pil')
31
+
32
+ gr.Interface(
33
+ fn=generate_image,
34
+ inputs=[class_input],
35
+ outputs=output_image,
36
+ title="MNIST Generator",
37
+ description="Generate images of handwritten digits from the MNIST dataset using a GAN.",
38
+ theme="default",
39
+ layout="vertical",
40
+ live=True
41
+ ).launch(debug=True)