rajsecrets0 commited on
Commit
2b03e8b
1 Parent(s): 3c08d3a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +72 -0
  2. cifar10_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torch import nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ # Define the model architecture (same as before)
9
+ class SimpleCNN(nn.Module):
10
+ def __init__(self):
11
+ super(SimpleCNN, self).__init__()
12
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
13
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
14
+ self.pool = nn.MaxPool2d(2, 2)
15
+ self.fc1 = nn.Linear(64 * 8 * 8, 512)
16
+ self.fc2 = nn.Linear(512, 10)
17
+
18
+ def forward(self, x):
19
+ x = self.pool(torch.relu(self.conv1(x)))
20
+ x = self.pool(torch.relu(self.conv2(x)))
21
+ x = x.view(-1, 64 * 8 * 8)
22
+ x = torch.relu(self.fc1(x))
23
+ x = self.fc2(x)
24
+ return x
25
+
26
+ # Load the trained model
27
+ @st.cache_resource
28
+ def load_model():
29
+ model = SimpleCNN()
30
+ model.load_state_dict(torch.load('cifar10_model.pth', map_location=torch.device('cpu')))
31
+ model.eval()
32
+ return model
33
+
34
+ # Define class names
35
+ class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
36
+
37
+ # Define image transformation
38
+ transform = transforms.Compose([
39
+ transforms.Resize((32, 32)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
42
+ ])
43
+
44
+ # Streamlit app
45
+ st.title('CIFAR-10 Image Classification')
46
+
47
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
48
+
49
+ if uploaded_file is not None:
50
+ image = Image.open(uploaded_file)
51
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
52
+
53
+ # Preprocess the image
54
+ input_tensor = transform(image).unsqueeze(0)
55
+
56
+ # Load model and make prediction
57
+ model = load_model()
58
+ with torch.no_grad():
59
+ output = model(input_tensor)
60
+
61
+ # Get the predicted class
62
+ _, predicted_idx = torch.max(output, 1)
63
+ predicted_class = class_names[predicted_idx.item()]
64
+
65
+ # Display the result
66
+ st.write(f"Prediction: {predicted_class}")
67
+
68
+ # Display probabilities
69
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
70
+ st.write("Class Probabilities:")
71
+ for i, prob in enumerate(probabilities):
72
+ st.write(f"{class_names[i]}: {prob.item():.2%}")
cifar10_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1323f6ba07e72febc1e1736184c24752394c58f6611c95f1d21f70e733c52f76
3
+ size 8491976