SonFox2920 commited on
Commit
aed03ad
·
verified ·
1 Parent(s): 544efb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -41
app.py CHANGED
@@ -26,6 +26,28 @@ st.markdown("""
26
  text-align: center;
27
  padding: 2rem;
28
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  </style>
30
  """, unsafe_allow_html=True)
31
 
@@ -71,16 +93,27 @@ def preprocess_image(image):
71
 
72
  return img_array
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def main():
75
  # Title
76
  st.title("🪨 Stone Classification")
77
  st.write("Upload an image of a stone to classify its type")
78
 
79
  # Initialize session state for prediction if not exists
80
- if 'prediction' not in st.session_state:
81
- st.session_state.prediction = None
82
- if 'confidence' not in st.session_state:
83
- st.session_state.confidence = None
84
 
85
  # Create two columns
86
  col1, col2 = st.columns(2)
@@ -96,52 +129,58 @@ def main():
96
 
97
  # Add predict button
98
  if st.button("Predict"):
99
- try:
100
- # Load model
101
- model = load_model()
102
-
103
- # Preprocess image
104
- processed_image = preprocess_image(image)
105
-
106
- # Make prediction
107
- prediction = model.predict(np.expand_dims(processed_image, axis=0))
108
- class_names = ['10', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
109
-
110
- # Get prediction and confidence
111
- predicted_class = class_names[np.argmax(prediction)]
112
- confidence = float(np.max(prediction)) * 100
113
-
114
- # Store in session state
115
- st.session_state.prediction = predicted_class
116
- st.session_state.confidence = confidence
117
-
118
- except Exception as e:
119
- st.error(f"Error during prediction: {str(e)}")
120
 
121
  with col2:
122
  st.subheader("Prediction Results")
123
- if st.session_state.prediction is not None:
124
  # Create a card-like container for results
125
  results_container = st.container()
126
  with results_container:
127
- st.markdown("""
128
- <style>
129
- .prediction-card {
130
- padding: 2rem;
131
- border-radius: 0.5rem;
132
- background-color: #f0f2f6;
133
- margin: 1rem 0;
134
- }
135
- </style>
136
- """, unsafe_allow_html=True)
137
-
138
  st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
139
- st.markdown(f"### Predicted Class: {st.session_state.prediction}")
140
- st.markdown(f"### Confidence: {st.session_state.confidence:.2f}%")
 
141
  st.markdown("</div>", unsafe_allow_html=True)
142
 
143
- # Add confidence bar
144
- st.progress(st.session_state.confidence / 100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  else:
146
  st.info("Upload an image and click 'Predict' to see the results")
147
 
 
26
  text-align: center;
27
  padding: 2rem;
28
  }
29
+ .prediction-card {
30
+ padding: 2rem;
31
+ border-radius: 0.5rem;
32
+ background-color: #f0f2f6;
33
+ margin: 1rem 0;
34
+ }
35
+ .top-predictions {
36
+ margin-top: 2rem;
37
+ padding: 1rem;
38
+ background-color: white;
39
+ border-radius: 0.5rem;
40
+ box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
+ }
42
+ .prediction-bar {
43
+ display: flex;
44
+ align-items: center;
45
+ margin: 0.5rem 0;
46
+ }
47
+ .prediction-label {
48
+ width: 100px;
49
+ font-weight: 500;
50
+ }
51
  </style>
52
  """, unsafe_allow_html=True)
53
 
 
93
 
94
  return img_array
95
 
96
+ def get_top_predictions(prediction, class_names, top_k=5):
97
+ """Get top k predictions with their probabilities"""
98
+ # Get indices of top k predictions
99
+ top_indices = prediction.argsort()[0][-top_k:][::-1]
100
+
101
+ # Get corresponding class names and probabilities
102
+ top_predictions = [
103
+ (class_names[i], float(prediction[0][i]) * 100)
104
+ for i in top_indices
105
+ ]
106
+
107
+ return top_predictions
108
+
109
  def main():
110
  # Title
111
  st.title("🪨 Stone Classification")
112
  st.write("Upload an image of a stone to classify its type")
113
 
114
  # Initialize session state for prediction if not exists
115
+ if 'predictions' not in st.session_state:
116
+ st.session_state.predictions = None
 
 
117
 
118
  # Create two columns
119
  col1, col2 = st.columns(2)
 
129
 
130
  # Add predict button
131
  if st.button("Predict"):
132
+ with st.spinner('Analyzing image...'):
133
+ try:
134
+ # Load model
135
+ model = load_model()
136
+
137
+ # Preprocess image
138
+ processed_image = preprocess_image(image)
139
+
140
+ # Make prediction
141
+ prediction = model.predict(np.expand_dims(processed_image, axis=0))
142
+ class_names = ['10', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
143
+
144
+ # Get top 5 predictions
145
+ top_predictions = get_top_predictions(prediction, class_names)
146
+
147
+ # Store in session state
148
+ st.session_state.predictions = top_predictions
149
+
150
+ except Exception as e:
151
+ st.error(f"Error during prediction: {str(e)}")
 
152
 
153
  with col2:
154
  st.subheader("Prediction Results")
155
+ if st.session_state.predictions is not None:
156
  # Create a card-like container for results
157
  results_container = st.container()
158
  with results_container:
159
+ # Display main prediction
 
 
 
 
 
 
 
 
 
 
160
  st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
161
+ top_class, top_confidence = st.session_state.predictions[0]
162
+ st.markdown(f"### Primary Prediction: Grade {top_class}")
163
+ st.markdown(f"### Confidence: {top_confidence:.2f}%")
164
  st.markdown("</div>", unsafe_allow_html=True)
165
 
166
+ # Display confidence bar for top prediction
167
+ st.progress(top_confidence / 100)
168
+
169
+ # Display top 5 predictions
170
+ st.markdown("### Top 5 Predictions")
171
+ st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
172
+
173
+ # Create a Streamlit container for the predictions
174
+ for class_name, confidence in st.session_state.predictions:
175
+ col_label, col_bar, col_value = st.columns([2, 6, 2])
176
+ with col_label:
177
+ st.write(f"Grade {class_name}")
178
+ with col_bar:
179
+ st.progress(confidence / 100)
180
+ with col_value:
181
+ st.write(f"{confidence:.2f}%")
182
+
183
+ st.markdown("</div>", unsafe_allow_html=True)
184
  else:
185
  st.info("Upload an image and click 'Predict' to see the results")
186