neelimapreeti297 commited on
Commit
e01b6a7
·
verified ·
1 Parent(s): 973f8ce
Files changed (4) hide show
  1. app.py +33 -0
  2. app_data/cat.jpg +0 -0
  3. app_data/dog.jpg +0 -0
  4. app_data/panda.jpg +0 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+
5
+ model = torch.jit.load("./models/cat_dog_cnn.pt")
6
+ model.eval()
7
+
8
+
9
+ transform = transforms.Compose([
10
+ transforms.Resize((224,224)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
13
+ ])
14
+
15
+ CLASSES = ["Cat", "Dog", "Panda"]
16
+
17
+ def classify_image(inp):
18
+ inp = transform(inp).unsqueeze(0)
19
+ out = model(inp)
20
+ return CLASSES[out.argmax().item()]
21
+
22
+ iface = gr.Interface(fn=classify_image,
23
+ inputs=gr.Image(type="pil", label="Input Image"),
24
+ outputs="text",
25
+ examples=[
26
+
27
+ "./app_data/cat.jpg",
28
+ "./app_data/dog.jpg",
29
+ "./app_data/panda.jpg",
30
+
31
+
32
+ ])
33
+ iface.launch()
app_data/cat.jpg ADDED
app_data/dog.jpg ADDED
app_data/panda.jpg ADDED