shrestha-prabin commited on
Commit
392c46c
·
1 Parent(s): 931038c

add classifier

Browse files
Files changed (2) hide show
  1. best_model.pth +3 -0
  2. src/streamlit_app.py +175 -37
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3870cefaa28a3803fddac25f8316880e0fadba580910acded5185df3eebba82e
3
+ size 78129506
src/streamlit_app.py CHANGED
@@ -1,40 +1,178 @@
1
- import altair as alt
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
  import numpy as np
 
3
  import streamlit as st
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
 
10
+ st.set_page_config(page_title="Garbage Classification")
11
+
12
+
13
+ # CNN Model Definition
14
+ class SimpleCNN(nn.Module):
15
+ def __init__(self, num_classes, input_channels=3):
16
+ super().__init__()
17
+
18
+ # Convolutional layers
19
+ self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=0)
20
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
21
+
22
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=0)
23
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
24
+
25
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=0)
26
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
27
+
28
+ self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
29
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
30
+
31
+ self.flatten = nn.Flatten()
32
+
33
+ # Dense layers
34
+ self.fc1 = nn.Linear(256 * 12 * 12, 512)
35
+ self.dropout1 = nn.Dropout(0.5)
36
+
37
+ self.fc2 = nn.Linear(512, 512)
38
+ self.dropout2 = nn.Dropout(0.5)
39
+
40
+ self.fc3 = nn.Linear(512, num_classes)
41
+
42
+ def forward(self, x):
43
+ # Conv blocks
44
+ x = F.relu(self.conv1(x))
45
+ x = self.pool1(x)
46
+
47
+ x = F.relu(self.conv2(x))
48
+ x = self.pool2(x)
49
+
50
+ x = F.relu(self.conv3(x))
51
+ x = self.pool3(x)
52
+
53
+ x = F.relu(self.conv4(x))
54
+ x = self.pool4(x)
55
+
56
+ # Dense layers
57
+ x = self.flatten(x)
58
+ x = F.relu(self.fc1(x))
59
+ x = self.dropout1(x)
60
+
61
+ x = F.relu(self.fc2(x))
62
+ x = self.dropout2(x)
63
+
64
+ x = self.fc3(x)
65
+ return x
66
+
67
+
68
+ # Class names
69
+ CLASS_NAMES = [
70
+ "battery",
71
+ "biological",
72
+ "cardboard",
73
+ "clothes",
74
+ "glass",
75
+ "metal",
76
+ "paper",
77
+ "plastic",
78
+ "shoes",
79
+ "trash",
80
+ ]
81
+
82
+
83
+ # Cache the model loading
84
+ @st.cache_resource
85
+ def load_model():
86
+ """Load the trained model"""
87
+ device = torch.device("cpu")
88
+ model = SimpleCNN(num_classes=10)
89
+ model = nn.DataParallel(model)
90
+
91
+ try:
92
+ model.load_state_dict(torch.load("best_model.pth", map_location=device))
93
+ model.eval()
94
+ return model, device
95
+ except Exception as e:
96
+ st.error(f"Error loading model: {e}")
97
+ return None, device
98
+
99
+
100
+ def preprocess_image(image):
101
+ """Preprocess uploaded image"""
102
+ transform = T.Compose(
103
+ [
104
+ T.Resize(224),
105
+ T.CenterCrop(224),
106
+ T.ToTensor(),
107
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
108
+ ]
109
+ )
110
+
111
+ image_tensor = transform(image).unsqueeze(0)
112
+ return image_tensor
113
+
114
+
115
+ def predict_image(image, model, device):
116
+ """Make prediction on image"""
117
+ # Preprocess image
118
+ input_tensor = preprocess_image(image).to(device)
119
+
120
+ # Make prediction
121
+ with torch.no_grad():
122
+ outputs = model(input_tensor)
123
+ probabilities = F.softmax(outputs, dim=1)
124
+ confidence, predicted_idx = torch.max(probabilities, 1)
125
+
126
+ predicted_class = CLASS_NAMES[predicted_idx.item()]
127
+ confidence_score = confidence.item()
128
+ all_probabilities = probabilities.cpu().numpy().flatten()
129
+
130
+ return predicted_class, confidence_score, all_probabilities
131
+
132
+
133
+ def get_confidence_color(confidence):
134
+ """Get color class based on confidence score"""
135
+ if confidence >= 0.7:
136
+ return "confidence-high"
137
+ elif confidence >= 0.4:
138
+ return "confidence-medium"
139
+ else:
140
+ return "confidence-low"
141
+
142
+
143
+ def main():
144
+ # Load model
145
+ model, device = load_model()
146
+
147
+ # File uploader
148
+ st.header("Garbage Classification")
149
+ uploaded_file = st.file_uploader(
150
+ "Choose an image file",
151
+ type=["jpg", "jpeg", "png"],
152
+ )
153
+
154
+ if uploaded_file is not None:
155
+ # Display uploaded image
156
+ image = Image.open(uploaded_file).convert("RGB")
157
+
158
+ col1, col2 = st.columns([1, 1])
159
+ with col1:
160
+ st.image(image, caption="Uploaded Image", use_container_width=True)
161
+
162
+ # Make prediction
163
+ with st.spinner("🔍 Analyzing image..."):
164
+ predicted_class, confidence, probabilities = predict_image(
165
+ image, model, device
166
+ )
167
+
168
+ sorted_indices = np.argsort(probabilities)[::-1]
169
+
170
+ container = col2.container(border=True)
171
+ for i, idx in enumerate(sorted_indices):
172
+ class_name = CLASS_NAMES[idx]
173
+ prob = probabilities[idx]
174
+ container.write(f"{class_name.title()}: {prob:.1%}")
175
+
176
+
177
+ if __name__ == "__main__":
178
+ main()