rajsecrets0 commited on
Commit
d12ec0c
·
verified ·
1 Parent(s): c790cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -137
app.py CHANGED
@@ -1,138 +1,156 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- import torchvision
6
- import torchvision.transforms as transforms
7
- from PIL import Image
8
- import io
9
-
10
- # Set page config
11
- st.set_page_config(page_title="CIFAR-10 Classifier", layout="centered", initial_sidebar_state="collapsed")
12
-
13
- # Custom CSS for dark theme
14
- st.markdown("""
15
- <style>
16
- .stApp {
17
- background-color: #0E1117;
18
- color: #FAFAFA;
19
- }
20
- .stButton>button {
21
- background-color: #4CAF50;
22
- color: white;
23
- }
24
- .stHeader {
25
- background-color: #262730;
26
- color: white;
27
- padding: 1rem;
28
- border-radius: 5px;
29
- }
30
- .stImage {
31
- background-color: #262730;
32
- padding: 10px;
33
- border-radius: 5px;
34
- }
35
- .stSuccess {
36
- background-color: #262730;
37
- color: #4CAF50;
38
- padding: 10px;
39
- border-radius: 5px;
40
- }
41
- </style>
42
- """, unsafe_allow_html=True)
43
-
44
- # Model definition
45
- class SimpleCNN(nn.Module):
46
- def __init__(self):
47
- super(SimpleCNN, self).__init__()
48
- self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
49
- self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
50
- self.pool = nn.MaxPool2d(2, 2)
51
- self.fc1 = nn.Linear(64 * 8 * 8, 512)
52
- self.fc2 = nn.Linear(512, 10)
53
-
54
- def forward(self, x):
55
- x = self.pool(torch.relu(self.conv1(x)))
56
- x = self.pool(torch.relu(self.conv2(x)))
57
- x = x.view(-1, 64 * 8 * 8)
58
- x = torch.relu(self.fc1(x))
59
- x = self.fc2(x)
60
- return x
61
-
62
- # Function to train the model
63
- @st.cache_resource
64
- def train_model():
65
- transform = transforms.Compose([
66
- transforms.ToTensor(),
67
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
68
- ])
69
-
70
- trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
71
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
72
-
73
- model = SimpleCNN()
74
- criterion = nn.CrossEntropyLoss()
75
- optimizer = optim.Adam(model.parameters(), lr=0.001)
76
-
77
- for epoch in range(5): # Train for 5 epochs
78
- for i, data in enumerate(trainloader, 0):
79
- inputs, labels = data
80
- optimizer.zero_grad()
81
- outputs = model(inputs)
82
- loss = criterion(outputs, labels)
83
- loss.backward()
84
- optimizer.step()
85
-
86
- return model
87
-
88
- # Function to load or train the model
89
- @st.cache_resource
90
- def get_model():
91
- try:
92
- model = SimpleCNN()
93
- model.load_state_dict(torch.load('cifar10_model.pth'))
94
- model.eval()
95
- except:
96
- model = train_model()
97
- torch.save(model.state_dict(), 'cifar10_model.pth')
98
- return model
99
-
100
- # Streamlit app
101
- st.markdown("<h1 class='stHeader'>CIFAR-10 Image Classification</h1>", unsafe_allow_html=True)
102
- st.write("Upload an image to classify it into one of the CIFAR-10 categories.")
103
-
104
- # File uploader
105
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
106
-
107
- if uploaded_file is not None:
108
- # Display uploaded image
109
- image = Image.open(uploaded_file)
110
- st.markdown("<div class='stImage'>", unsafe_allow_html=True)
111
- st.image(image, caption='Uploaded Image', use_column_width=True)
112
- st.markdown("</div>", unsafe_allow_html=True)
113
-
114
- # Predict button
115
- if st.button('Classify Image'):
116
- # Load model
117
- model = get_model()
118
-
119
- # Preprocess image
120
- transform = transforms.Compose([
121
- transforms.Resize((32, 32)),
122
- transforms.ToTensor(),
123
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
124
- ])
125
- input_tensor = transform(image).unsqueeze(0)
126
-
127
- # Make prediction
128
- with torch.no_grad():
129
- output = model(input_tensor)
130
- _, predicted = torch.max(output, 1)
131
-
132
- # Display result
133
- classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
134
- st.markdown(f"<div class='stSuccess'>Prediction: {classes[predicted.item()]}</div>", unsafe_allow_html=True)
135
-
136
- # Footer
137
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  st.markdown("<p style='text-align: center; color: #666;'>Created with Streamlit and PyTorch</p>", unsafe_allow_html=True)
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ import io
9
+
10
+ # Set page config
11
+ st.set_page_config(page_title="CIFAR-10 Classifier", layout="wide", initial_sidebar_state="expanded")
12
+
13
+ # Custom CSS for dark theme
14
+ st.markdown("""
15
+ <style>
16
+ .stApp {
17
+ background-color: #0E1117;
18
+ color: #FAFAFA;
19
+ }
20
+ .stButton>button {
21
+ background-color: #4CAF50;
22
+ color: white;
23
+ }
24
+ .stHeader {
25
+ background-color: #262730;
26
+ color: white;
27
+ padding: 1rem;
28
+ border-radius: 5px;
29
+ margin-bottom: 1rem;
30
+ }
31
+ .stImage {
32
+ background-color: #262730;
33
+ padding: 10px;
34
+ border-radius: 5px;
35
+ }
36
+ .stSuccess {
37
+ background-color: #262730;
38
+ color: #4CAF50;
39
+ padding: 10px;
40
+ border-radius: 5px;
41
+ margin-top: 1rem;
42
+ }
43
+ .upload-box {
44
+ border: 2px dashed #4CAF50;
45
+ border-radius: 5px;
46
+ padding: 20px;
47
+ text-align: center;
48
+ cursor: pointer;
49
+ }
50
+ </style>
51
+ """, unsafe_allow_html=True)
52
+
53
+ # Model definition
54
+ class SimpleCNN(nn.Module):
55
+ def __init__(self):
56
+ super(SimpleCNN, self).__init__()
57
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
58
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
59
+ self.pool = nn.MaxPool2d(2, 2)
60
+ self.fc1 = nn.Linear(64 * 8 * 8, 512)
61
+ self.fc2 = nn.Linear(512, 10)
62
+
63
+ def forward(self, x):
64
+ x = self.pool(torch.relu(self.conv1(x)))
65
+ x = self.pool(torch.relu(self.conv2(x)))
66
+ x = x.view(-1, 64 * 8 * 8)
67
+ x = torch.relu(self.fc1(x))
68
+ x = self.fc2(x)
69
+ return x
70
+
71
+ # Function to train the model
72
+ @st.cache_resource
73
+ def train_model():
74
+ transform = transforms.Compose([
75
+ transforms.ToTensor(),
76
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
77
+ ])
78
+
79
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
80
+ trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
81
+
82
+ model = SimpleCNN()
83
+ criterion = nn.CrossEntropyLoss()
84
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
85
+
86
+ for epoch in range(5): # Train for 5 epochs
87
+ for i, data in enumerate(trainloader, 0):
88
+ inputs, labels = data
89
+ optimizer.zero_grad()
90
+ outputs = model(inputs)
91
+ loss = criterion(outputs, labels)
92
+ loss.backward()
93
+ optimizer.step()
94
+
95
+ return model
96
+
97
+ # Function to load or train the model
98
+ @st.cache_resource
99
+ def get_model():
100
+ try:
101
+ model = SimpleCNN()
102
+ model.load_state_dict(torch.load('cifar10_model.pth'))
103
+ model.eval()
104
+ except:
105
+ model = train_model()
106
+ torch.save(model.state_dict(), 'cifar10_model.pth')
107
+ return model
108
+
109
+ # Sidebar
110
+ st.sidebar.title("Upload Image")
111
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
112
+
113
+ # Main content
114
+ st.markdown("<h1 class='stHeader'>CIFAR-10 Image Classification</h1>", unsafe_allow_html=True)
115
+
116
+ # Drag and drop section
117
+ col1, col2, col3 = st.columns([1,2,1])
118
+ # with col2:
119
+ # st.markdown("<div class='upload-box'>Drag and drop image here</div>", unsafe_allow_html=True)
120
+
121
+ # Display uploaded image and make prediction
122
+ if uploaded_file is not None:
123
+ image = Image.open(uploaded_file)
124
+ col1, col2, col3 = st.columns([1,2,1])
125
+ with col2:
126
+ st.markdown("<div class='stImage'>", unsafe_allow_html=True)
127
+ st.image(image, caption='Uploaded Image', use_column_width=True)
128
+ st.markdown("</div>", unsafe_allow_html=True)
129
+
130
+ # Load model and make prediction
131
+ model = get_model()
132
+ transform = transforms.Compose([
133
+ transforms.Resize((32, 32)),
134
+ transforms.ToTensor(),
135
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
136
+ ])
137
+ input_tensor = transform(image).unsqueeze(0)
138
+
139
+ with torch.no_grad():
140
+ output = model(input_tensor)
141
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
142
+
143
+ # Display results in sidebar
144
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
145
+ _, predicted = torch.max(output, 1)
146
+ st.sidebar.markdown("<div class='stSuccess'>", unsafe_allow_html=True)
147
+ st.sidebar.write(f"Best Prediction: {classes[predicted.item()]}")
148
+ st.sidebar.markdown("</div>", unsafe_allow_html=True)
149
+
150
+ st.sidebar.write("Prediction Probabilities:")
151
+ for i, prob in enumerate(probabilities):
152
+ st.sidebar.write(f"{classes[i]}: {prob.item():.2%}")
153
+
154
+ # Footer
155
+ st.markdown("---")
156
  st.markdown("<p style='text-align: center; color: #666;'>Created with Streamlit and PyTorch</p>", unsafe_allow_html=True)