ved1beta commited on
Commit
e99930d
·
1 Parent(s): ac456a1
Files changed (1) hide show
  1. app.py +76 -4
app.py CHANGED
@@ -1,7 +1,79 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import numpy as np
8
 
9
+ # Define the same model architecture
10
+ class ConvNet(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.conv1 = nn.Conv2d(3, 32, 3)
14
+ self.pool = nn.MaxPool2d(2, 2)
15
+ self.conv2 = nn.Conv2d(32, 64, 3)
16
+ self.conv3 = nn.Conv2d(64, 64, 3)
17
+ self.fc1 = nn.Linear(64 * 4 * 4, 64)
18
+ self.fc2 = nn.Linear(64, 10)
19
+
20
+ def forward(self, x):
21
+ x = F.relu(self.conv1(x))
22
+ x = self.pool(x)
23
+ x = F.relu(self.conv2(x))
24
+ x = self.pool(x)
25
+ x = F.relu(self.conv3(x))
26
+ x = torch.flatten(x, 1)
27
+ x = F.relu(self.fc1(x))
28
+ x = self.fc2(x)
29
+ return x
30
 
31
+ # Initialize model and load weights
32
+ model = ConvNet()
33
+ model.load_state_dict(torch.load('cnn.pth', map_location=torch.device('cpu')))
34
+ model.eval()
35
+
36
+ # Define classes
37
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
38
+
39
+ # Define preprocessing
40
+ transform = transforms.Compose([
41
+ transforms.Resize((32, 32)),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
44
+ ])
45
+
46
+ def predict(img):
47
+ # Convert to PIL Image if needed
48
+ if not isinstance(img, Image.Image):
49
+ img = Image.fromarray(img)
50
+
51
+ # Preprocess the image
52
+ img = transform(img).unsqueeze(0)
53
+
54
+ # Get predictions
55
+ with torch.no_grad():
56
+ outputs = model(img)
57
+ probabilities = F.softmax(outputs, dim=1)
58
+
59
+ # Get top 3 predictions
60
+ probs, indices = torch.topk(probabilities, 3)
61
+ predictions = []
62
+ for prob, idx in zip(probs[0], indices[0]):
63
+ predictions.append((classes[idx], float(prob)))
64
+
65
+ # Format the results
66
+ return {pred[0]: pred[1] for pred in predictions}
67
+
68
+ # Create Gradio interface
69
+ iface = gr.Interface(
70
+ fn=predict,
71
+ inputs=gr.Image(type="pil"),
72
+ outputs=gr.Label(num_top_classes=3),
73
+ examples=[["example1.jpg"], ["example2.jpg"]], # Optional: Add example images
74
+ title="CIFAR-10 Image Classifier",
75
+ description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck"
76
+ )
77
+
78
+ # Launch the app
79
+ iface.launch()