SonFox2920 committed on
Commit 435eb8d · verified · 1 Parent(s): e30bd59

Update app.py

Files changed (1)
  app.py +118 -183
app.py CHANGED
@@ -3,21 +3,12 @@ import tensorflow as tf
  import numpy as np
  import cv2
  from PIL import Image
- import io
- import os
- import cv2
- import numpy as np
- from tensorflow import keras
- from tensorflow.keras import layers, models
- from sklearn.preprocessing import StandardScaler
- from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
- import matplotlib.pyplot as plt
- import random
  from tensorflow.keras import layers, models
  from tensorflow.keras.applications import EfficientNetB0
  from tensorflow.keras.applications.efficientnet import preprocess_input
- from tensorflow.keras.layers import Lambda  # Make sure Lambda is imported from tensorflow.keras.layers

  # Set page config
  st.set_page_config(
      page_title="Stone Classification",
@@ -25,7 +16,7 @@ st.set_page_config(
      layout="wide"
  )

- # Custom CSS to improve the appearance
  st.markdown("""
      <style>
      .main {
@@ -35,10 +26,6 @@ st.markdown("""
          width: 100%;
          margin-top: 1rem;
      }
-     .upload-text {
-         text-align: center;
-         padding: 2rem;
-     }
      .prediction-card {
          padding: 2rem;
          border-radius: 0.5rem;
@@ -52,202 +39,159 @@ st.markdown("""
          border-radius: 0.5rem;
          box-shadow: 0 1px 3px rgba(0,0,0,0.12);
      }
-     .prediction-bar {
-         display: flex;
-         align-items: center;
-         margin: 0.5rem 0;
-     }
-     .prediction-label {
-         width: 100px;
-         font-weight: 500;
-     }
      </style>
  """, unsafe_allow_html=True)

  @st.cache_resource
  def load_model():
      """Load the trained model"""
-     return tf.keras.models.load_model('mlp_model.h5')
-
  def color_histogram(image, bins=16):
-     # (Previous implementation remains the same)
      hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
      hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
      hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()

-     hist_r = hist_r / np.sum(hist_r)
-     hist_g = hist_g / np.sum(hist_g)
-     hist_b = hist_b / np.sum(hist_b)

      return np.concatenate([hist_r, hist_g, hist_b])

  def color_moments(image):
-     # (Previous implementation remains the same)
      img = image.astype(np.float32) / 255.0
-
      moments = []
-     for i in range(3):  # For each color channel
          channel = img[:,:,i]
-
          mean = np.mean(channel)
-         std = np.std(channel)
-         skewness = np.mean(((channel - mean) / std) ** 3)
-
          moments.extend([mean, std, skewness])

      return np.array(moments)

  def dominant_color_descriptor(image, k=3):
-     # (Previous implementation remains the same)
-     pixels = image.reshape(-1, 3)

      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
      flags = cv2.KMEANS_RANDOM_CENTERS

      try:
-         _, labels, centers = cv2.kmeans(pixels.astype(np.float32), k, None, criteria, 10, flags)
-
          unique, counts = np.unique(labels, return_counts=True)
          percentages = counts / len(labels)
-
-         dominant_colors = centers.flatten()
-         color_percentages = percentages
-
-         return np.concatenate([dominant_colors, color_percentages])
-     except:
-         return np.zeros(2 * k)

  def color_coherence_vector(image, k=3):
-     # Convert to grayscale
      gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-
-     # Convert the grayscale image to 8-bit format before applying threshold
      gray = np.uint8(gray)

-     # Apply Otsu's thresholding method
      _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-
-     # Perform connected components analysis
      num_labels, labels = cv2.connectedComponents(binary)

      ccv = []
      for i in range(1, min(k+1, num_labels)):
          region_mask = (labels == i)
          total_pixels = np.sum(region_mask)
-         coherent_pixels = total_pixels
-
-         ccv.extend([coherent_pixels, total_pixels])
-
-     while len(ccv) < 2 * k:
-         ccv.append(0)
-
-     return np.array(ccv)
-
-
- # ViT and Feature Extraction Functions (from previous implementation)
- # (Keeping the Patches, PatchEncoder, and create_vit_feature_extractor functions)
-
- def extract_features(image):
-     """
-     Extract multiple features from an image
-     """
-     color_hist = color_histogram(image)
-     color_mom = color_moments(image)
-     dom_color = dominant_color_descriptor(image)
-     ccv = color_coherence_vector(image)

-     return np.concatenate([color_hist, color_mom, dom_color, ccv])
-
- from transformers import ViTFeatureExtractor, ViTModel
- import torch
- from tensorflow.keras import layers, models

- def create_vit_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
-     # Build the feature extractor from a pre-trained TensorFlow model
      inputs = layers.Input(shape=input_shape)

-     # Add a Lambda layer to preprocess the image
-     x = Lambda(preprocess_input, output_shape=input_shape)(inputs)  # Preprocess the input image
-
-     # You can replace this part with a pre-trained ViT model.
-     # Below, EfficientNetB0 is used instead of ViT as an example.
-     # Create the ViT model, or use another pre-trained model
-     vit_model = EfficientNetB0(include_top=False, weights='imagenet', input_tensor=x)
-
-     # Extract features from the backbone
-     x = layers.GlobalAveragePooling2D()(vit_model.output)
-
-     if num_classes:
-         x = layers.Dense(num_classes, activation='softmax')(x)  # Add a classification head (if any)

      return models.Model(inputs=inputs, outputs=x)

  def preprocess_image(image):
      """Preprocess the uploaded image"""
      # Convert to RGB if needed
      if image.mode != 'RGB':
          image = image.convert('RGB')

-     # Convert to numpy array
      img_array = np.array(image)
-
-     # Ensure RGB format
-     if len(img_array.shape) == 2:  # Grayscale
-         img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
-     elif img_array.shape[2] == 4:  # RGBA
-         img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
-
-     # Resize
      img_array = cv2.resize(img_array, (256, 256))
-
-     # Normalize
      img_array = img_array.astype('float32') / 255.0

      # Extract traditional features
-     image_features = extract_features(img_array)

-     # Create and process ViT features
-     vit_extractor = create_vit_feature_extractor()
-
-     # Reshape image for ViT processing - THIS IS THE KEY FIX
-     img_for_vit = np.expand_dims(img_array, axis=0)  # Add batch dimension
-     image_vit = vit_extractor.predict(img_for_vit)
-
-     # Flatten ViT features if needed
-     image_vit = image_vit.reshape(1, -1)  # Ensure 2D shape

      # Combine features
-     image_combined = np.concatenate([image_features.reshape(1, -1), image_vit], axis=1)

      # Scale features
      scaler = StandardScaler()
-     image_scaled = scaler.fit_transform(image_combined)
-
-     return image_scaled.squeeze()  # Remove any unnecessary dimensions

  def get_top_predictions(prediction, class_names, top_k=5):
      """Get top k predictions with their probabilities"""
-     # Get indices of top k predictions
      top_indices = prediction.argsort()[0][-top_k:][::-1]
-
-     # Get corresponding class names and probabilities
-     top_predictions = [
          (class_names[i], float(prediction[0][i]) * 100)
          for i in top_indices
      ]
-
-     return top_predictions

  def main():
-     # Title
      st.title("🪨 Stone Classification")
      st.write("Upload an image of a stone to classify its type")

-     # Initialize session state for prediction if not exists
      if 'predictions' not in st.session_state:
          st.session_state.predictions = None

-     # Create two columns
      col1, col2 = st.columns(2)

      with col1:
@@ -255,69 +199,60 @@ def main():
          uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

          if uploaded_file is not None:
-             # Display uploaded image
-             image = Image.open(uploaded_file)
-             st.image(image, caption="Uploaded Image", use_column_width=True)
-
-             with st.spinner('Analyzing image...'):
-                 try:
-                     # Load model
                      model = load_model()

-                     # Preprocess image
                      processed_image = preprocess_image(image)

-                     # Ensure correct shape for prediction
-                     processed_image = np.expand_dims(processed_image, axis=0)
-
-                     # Make prediction
-                     prediction = model.predict(processed_image)
                      class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']

-                     # Get top 5 predictions
-                     top_predictions = get_top_predictions(prediction, class_names)
-
-                     # Store in session state
-                     st.session_state.predictions = top_predictions
-
-                 except Exception as e:
-                     st.error(f"Error during prediction: {str(e)}")

      with col2:
          st.subheader("Prediction Results")
-         if st.session_state.predictions is not None:
-             # Create a card-like container for results
-             results_container = st.container()
-             with results_container:
-                 # Display main prediction
-                 st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
-                 top_class, top_confidence = st.session_state.predictions[0]
-                 st.markdown(f"### Primary Prediction: Grade {top_class}")
-                 st.markdown(f"### Confidence: {top_confidence:.2f}%")
-                 st.markdown("</div>", unsafe_allow_html=True)
-
-                 # Display confidence bar for top prediction
-                 st.progress(top_confidence / 100)
-
-                 # Display top 5 predictions
-                 st.markdown("### Top 5 Predictions")
-                 st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
-
-                 # Create a Streamlit container for the predictions
-                 for class_name, confidence in st.session_state.predictions:
-                     col_label, col_bar, col_value = st.columns([2, 6, 2])
-                     with col_label:
-                         st.write(f"Grade {class_name}")
-                     with col_bar:
-                         st.progress(confidence / 100)
-                     with col_value:
-                         st.write(f"{confidence:.2f}%")
-
-                 st.markdown("</div>", unsafe_allow_html=True)
          else:
-             st.info("Upload an image and click 'Predict' to see the results")

-     # Footer
      st.markdown("---")
      st.markdown("Made with ❤️ using Streamlit")

  import numpy as np
  import cv2
  from PIL import Image
  from tensorflow.keras import layers, models
  from tensorflow.keras.applications import EfficientNetB0
  from tensorflow.keras.applications.efficientnet import preprocess_input
+ from sklearn.preprocessing import StandardScaler
+ import io
+
  # Set page config
  st.set_page_config(
      page_title="Stone Classification",

      layout="wide"
  )

+ # Custom CSS with improved styling
  st.markdown("""
      <style>
      .main {

          width: 100%;
          margin-top: 1rem;
      }

      .prediction-card {
          padding: 2rem;
          border-radius: 0.5rem;

          border-radius: 0.5rem;
          box-shadow: 0 1px 3px rgba(0,0,0,0.12);
      }
      </style>
  """, unsafe_allow_html=True)

  @st.cache_resource
  def load_model():
      """Load the trained model"""
+     try:
+         return tf.keras.models.load_model('mlp_model.h5')
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         return None
+
  def color_histogram(image, bins=16):
+     """Calculate color histogram features"""
      hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
      hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
      hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()

+     # Normalize histograms
+     hist_r = hist_r / (np.sum(hist_r) + 1e-7)
+     hist_g = hist_g / (np.sum(hist_g) + 1e-7)
+     hist_b = hist_b / (np.sum(hist_b) + 1e-7)

      return np.concatenate([hist_r, hist_g, hist_b])
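As a quick sanity check (an illustrative sketch, not part of the commit; it assumes the function above is in scope), the normalized histogram yields bins values per channel, i.e. a 48-dimensional vector for the default bins=16:

import numpy as np

dummy = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)  # stand-in RGB image
features = color_histogram(dummy, bins=16)
print(features.shape)  # expected: (48,) -- 3 channels x 16 normalized bins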

  def color_moments(image):
+     """Calculate color moments features"""
      img = image.astype(np.float32) / 255.0
      moments = []
+
+     for i in range(3):
          channel = img[:,:,i]
          mean = np.mean(channel)
+         std = np.std(channel) + 1e-7  # Avoid division by zero
+         skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
          moments.extend([mean, std, skewness])

      return np.array(moments)

  def dominant_color_descriptor(image, k=3):
+     """Calculate dominant color descriptor"""
+     pixels = image.reshape(-1, 3).astype(np.float32)

      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
      flags = cv2.KMEANS_RANDOM_CENTERS

      try:
+         _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
          unique, counts = np.unique(labels, return_counts=True)
          percentages = counts / len(labels)
+         return np.concatenate([centers.flatten(), percentages])
+     except Exception:
+         return np.zeros(k * 4)  # Return zero vector if clustering fails

  def color_coherence_vector(image, k=3):
+     """Calculate color coherence vector"""
      gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
      gray = np.uint8(gray)

      _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
      num_labels, labels = cv2.connectedComponents(binary)

      ccv = []
      for i in range(1, min(k+1, num_labels)):
          region_mask = (labels == i)
          total_pixels = np.sum(region_mask)
+         ccv.extend([total_pixels, total_pixels])

+     # Pad with zeros if needed
+     ccv.extend([0] * (2 * k - len(ccv)))
+     return np.array(ccv[:2*k])
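Taken together, the four hand-crafted descriptors have fixed lengths: 48 histogram bins, 9 color moments (mean, std, skewness per channel), 12 dominant-color values (3·k centers plus k percentages for k=3), and 2·k = 6 coherence values, i.e. 75 traditional features in total. A small verification sketch (assuming the functions above are in scope):

import numpy as np

dummy = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)  # stand-in RGB image
lengths = [len(color_histogram(dummy)),
           len(color_moments(dummy)),
           len(dominant_color_descriptor(dummy)),
           len(color_coherence_vector(dummy))]
print(lengths, sum(lengths))  # expected: [48, 9, 12, 6] 75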
 
 

+ @st.cache_resource
+ def create_feature_extractor():
+     """Create and cache the feature extractor model"""
+     input_shape = (256, 256, 3)
      inputs = layers.Input(shape=input_shape)
+     x = layers.Lambda(preprocess_input)(inputs)

+     base_model = EfficientNetB0(
+         include_top=False,
+         weights='imagenet',
+         input_tensor=x
+     )

+     x = layers.GlobalAveragePooling2D()(base_model.output)
      return models.Model(inputs=inputs, outputs=x)
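EfficientNetB0 with include_top=False followed by global average pooling yields a 1280-dimensional embedding, so the cached extractor can be smoke-tested as follows (a sketch; the first call downloads the ImageNet weights):

import numpy as np

extractor = create_feature_extractor()
dummy_batch = np.random.rand(1, 256, 256, 3).astype('float32')  # one 256x256 RGB image in [0, 1]
embedding = extractor.predict(dummy_batch, verbose=0)
print(embedding.shape)  # expected: (1, 1280)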
+
+ def extract_features(image):
+     """Extract all features from an image"""
+     # Convert image to uint8 for OpenCV operations
+     image_uint8 = (image * 255).astype(np.uint8)

+     # Extract traditional features
+     hist_features = color_histogram(image_uint8)
+     moment_features = color_moments(image_uint8)
+     dominant_features = dominant_color_descriptor(image_uint8)
+     ccv_features = color_coherence_vector(image_uint8)
+
+     return np.concatenate([
+         hist_features,
+         moment_features,
+         dominant_features,
+         ccv_features
+     ])
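extract_features therefore expects a float image in [0, 1] (it rescales to uint8 internally for the OpenCV descriptors) and returns the 75-value traditional feature vector. A hypothetical call, where stone.jpg is only a stand-in file name:

import cv2

img = cv2.cvtColor(cv2.imread('stone.jpg'), cv2.COLOR_BGR2RGB)  # hypothetical sample file
img = cv2.resize(img, (256, 256)).astype('float32') / 255.0
print(extract_features(img).shape)  # expected: (75,)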
+
  def preprocess_image(image):
      """Preprocess the uploaded image"""
      # Convert to RGB if needed
      if image.mode != 'RGB':
          image = image.convert('RGB')

+     # Convert to numpy array and resize
      img_array = np.array(image)
      img_array = cv2.resize(img_array, (256, 256))
      img_array = img_array.astype('float32') / 255.0

      # Extract traditional features
+     traditional_features = extract_features(img_array)

+     # Extract deep features
+     feature_extractor = create_feature_extractor()
+     deep_features = feature_extractor.predict(
+         np.expand_dims(img_array, axis=0),
+         verbose=0
+     )

      # Combine features
+     combined_features = np.concatenate([
+         traditional_features.reshape(1, -1),
+         deep_features.reshape(1, -1)
+     ], axis=1)

      # Scale features
      scaler = StandardScaler()
+     return scaler.fit_transform(combined_features)
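preprocess_image should therefore return a single row of 75 + 1280 = 1355 combined features (given the EfficientNetB0 backbone). Note that fitting a fresh StandardScaler on that single row standardizes each column against itself, so the scaled values come out as 0; reusing a scaler fitted on the training data would preserve the feature scales. A shape check (sketch, with a hypothetical input file):

from PIL import Image

model_input = preprocess_image(Image.open('stone.jpg'))  # hypothetical sample image
print(model_input.shape)  # expected: (1, 1355)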
 
 

  def get_top_predictions(prediction, class_names, top_k=5):
      """Get top k predictions with their probabilities"""
      top_indices = prediction.argsort()[0][-top_k:][::-1]
+     return [
          (class_names[i], float(prediction[0][i]) * 100)
          for i in top_indices
      ]
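Given a (1, num_classes) probability vector, get_top_predictions returns (label, percentage) pairs ordered by confidence, for example:

import numpy as np

class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
fake = np.array([[0.01, 0.02, 0.03, 0.04, 0.40, 0.25, 0.10, 0.06, 0.05, 0.04]])
print(get_top_predictions(fake, class_names, top_k=3))
# roughly: [('8', 40.0), ('8.5', 25.0), ('9', 10.0)]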
 
 

  def main():
      st.title("🪨 Stone Classification")
      st.write("Upload an image of a stone to classify its type")

+     # Initialize session state
      if 'predictions' not in st.session_state:
          st.session_state.predictions = None

      col1, col2 = st.columns(2)

      with col1:

          uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

          if uploaded_file is not None:
+             try:
+                 image = Image.open(uploaded_file)
+                 st.image(image, caption="Uploaded Image", use_column_width=True)
+
+                 with st.spinner('Analyzing image...'):
                      model = load_model()
+                     if model is None:
+                         st.error("Failed to load model")
+                         return

                      processed_image = preprocess_image(image)
+                     prediction = model.predict(processed_image, verbose=0)

                      class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
+                     st.session_state.predictions = get_top_predictions(prediction, class_names)

+             except Exception as e:
+                 st.error(f"Error processing image: {str(e)}")

      with col2:
          st.subheader("Prediction Results")
+         if st.session_state.predictions:
+             # Display main prediction
+             top_class, top_confidence = st.session_state.predictions[0]
+             st.markdown(
+                 f"""
+                 <div class='prediction-card'>
+                     <h3>Primary Prediction: Grade {top_class}</h3>
+                     <h3>Confidence: {top_confidence:.2f}%</h3>
+                 </div>
+                 """,
+                 unsafe_allow_html=True
+             )
+
+             # Display confidence bar
+             st.progress(top_confidence / 100)
+
+             # Display top 5 predictions
+             st.markdown("### Top 5 Predictions")
+             st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
+
+             for class_name, confidence in st.session_state.predictions:
+                 cols = st.columns([2, 6, 2])
+                 with cols[0]:
+                     st.write(f"Grade {class_name}")
+                 with cols[1]:
+                     st.progress(confidence / 100)
+                 with cols[2]:
+                     st.write(f"{confidence:.2f}%")
+
+             st.markdown("</div>", unsafe_allow_html=True)
          else:
+             st.info("Upload an image to see the predictions")

      st.markdown("---")
      st.markdown("Made with ❤️ using Streamlit")
258