SonFox2920 commited on
Commit
0e5cd3a
·
verified ·
1 Parent(s): 5446b36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -143
app.py CHANGED
@@ -3,6 +3,10 @@ import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
 
 
 
 
6
  import io
7
 
8
  # Set page config
@@ -12,7 +16,7 @@ st.set_page_config(
12
  layout="wide"
13
  )
14
 
15
- # Custom CSS to improve the appearance
16
  st.markdown("""
17
  <style>
18
  .main {
@@ -22,14 +26,10 @@ st.markdown("""
22
  width: 100%;
23
  margin-top: 1rem;
24
  }
25
- .upload-text {
26
- text-align: center;
27
- padding: 2rem;
28
- }
29
  .prediction-card {
30
  padding: 2rem;
31
  border-radius: 0.5rem;
32
- background-color: #f0f2f6;
33
  margin: 1rem 0;
34
  }
35
  .top-predictions {
@@ -39,109 +39,162 @@ st.markdown("""
39
  border-radius: 0.5rem;
40
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
  }
42
- .prediction-bar {
43
- display: flex;
44
- align-items: center;
45
- margin: 0.5rem 0;
46
- }
47
- .prediction-label {
48
- width: 100px;
49
- font-weight: 500;
50
- }
51
  </style>
52
  """, unsafe_allow_html=True)
53
 
 
54
  @st.cache_resource
55
- def load_model():
56
- """Load the trained model"""
57
- return tf.keras.models.load_model('custom_model.h5')
 
 
 
 
 
 
 
58
 
59
- def preprocess_image(image):
60
- """Preprocess the uploaded image"""
61
- # # Convert to RGB if needed
62
- # if image.mode != 'RGB':
63
- # image = image.convert('RGB')
64
 
65
- # Convert to numpy array
66
- img_array = np.array(image)
 
67
 
68
- # # Convert to RGB if needed
69
- # if len(img_array.shape) == 2: # Grayscale
70
- # img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
71
- # elif img_array.shape[2] == 4: # RGBA
72
- # img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
 
73
 
74
- # # Preprocess image similar to training
75
- # img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
76
- # img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
77
- # img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
 
 
78
 
79
- # # Adjust brightness
80
- # target_brightness = 150
81
- # current_brightness = np.mean(img_array)
82
- # alpha = target_brightness / (current_brightness + 1e-5)
83
- # img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
84
 
85
- # # Apply Gaussian blur
86
- # img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
87
 
88
- # Resize
89
- img_array = cv2.resize(img_array, (256, 256))
 
 
 
 
 
 
 
 
 
 
90
 
91
- # Normalize
92
- img_array = img_array.astype('float32') / 255.0
93
 
94
- return img_array
95
-
96
- from mega import Mega
97
-
98
- # Đăng nhập vào tài khoản Mega
99
- def upload_to_mega(file_path, folder_name):
100
- """
101
- Upload file to a specific folder on Mega.nz
102
- """
103
- try:
104
- # Đăng nhập vào tài khoản Mega
105
- mega = Mega()
106
- m = mega.login('[email protected]', '01283315889')
107
 
108
- # Tìm thư mục đích
109
- folder = m.find(folder_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- if not folder:
112
- # Nếu thư mục không tồn tại, hiển thị thông báo lỗi
113
- return f"Thư mục '{folder_name}' không tồn tại!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # Tải tệp lên thư mục
116
- file = m.upload(file_path, folder[0])
117
- return f"Upload thành công! Link: {m.get_upload_link(file)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- except Exception as e:
120
- return f"Lỗi khi tải lên Mega: {str(e)}"
121
-
122
  def get_top_predictions(prediction, class_names, top_k=5):
123
  """Get top k predictions with their probabilities"""
124
- # Get indices of top k predictions
125
  top_indices = prediction.argsort()[0][-top_k:][::-1]
126
-
127
- # Get corresponding class names and probabilities
128
- top_predictions = [
129
  (class_names[i], float(prediction[0][i]) * 100)
130
  for i in top_indices
131
  ]
132
-
133
- return top_predictions
134
 
135
  def main():
136
- # Title
137
  st.title("🪨 Stone Classification")
138
  st.write("Upload an image of a stone to classify its type")
139
 
140
- # Initialize session state for prediction if not exists
 
 
 
 
 
 
141
  if 'predictions' not in st.session_state:
142
  st.session_state.predictions = None
143
 
144
- # Create two columns
145
  col1, col2 = st.columns(2)
146
 
147
  with col1:
@@ -149,30 +202,19 @@ def main():
149
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
150
 
151
  if uploaded_file is not None:
152
- # Display uploaded image
153
- image = Image.open(uploaded_file)
154
- st.image(image, caption="Uploaded Image", use_column_width=True)
155
-
156
- with st.spinner('Analyzing image...'):
157
- try:
158
- # Load model
159
- model = load_model()
160
-
161
- # Preprocess image
162
- processed_image = preprocess_image(image)
163
 
164
- # Make prediction
165
- prediction = model.predict(np.expand_dims(processed_image, axis=0))
166
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
 
167
 
168
- # Get top 5 predictions
169
- top_predictions = get_top_predictions(prediction, class_names)
170
-
171
- # Store in session state
172
- st.session_state.predictions = top_predictions
173
-
174
- except Exception as e:
175
- st.error(f"Error during prediction: {str(e)}")
176
 
177
  with col2:
178
  st.subheader("Prediction Results")
@@ -188,14 +230,14 @@ def main():
188
  """,
189
  unsafe_allow_html=True
190
  )
191
-
192
  # Display confidence bar
193
  st.progress(top_confidence / 100)
194
-
195
  # Display top 5 predictions
196
  st.markdown("### Top 5 Predictions")
197
  st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
198
-
199
  for class_name, confidence in st.session_state.predictions:
200
  cols = st.columns([2, 6, 2])
201
  with cols[0]:
@@ -204,50 +246,11 @@ def main():
204
  st.progress(confidence / 100)
205
  with cols[2]:
206
  st.write(f"{confidence:.2f}%")
207
-
208
- st.markdown("</div>", unsafe_allow_html=True)
209
-
210
- # User Survey
211
- st.markdown("<div class='survey-card'>", unsafe_allow_html=True)
212
- st.markdown("### Model Accuracy Survey")
213
- st.write("Mô hình có dự đoán chính xác màu sắc của đá trong ảnh này không?")
214
-
215
- # Accuracy Confirmation
216
- accuracy = st.radio(
217
- "Đánh giá độ chính xác",
218
- ["Chọn", "Chính xác", "Không chính xác"],
219
- index=0
220
- )
221
-
222
- if accuracy == "Không chính xác":
223
- # Color input for incorrect prediction
224
- correct_color = st.text_input(
225
- "Vui lòng nhập màu sắc chính xác của đá:",
226
- help="Ví dụ: 10, 9.7, 9.5, 9.2, v.v."
227
- )
228
-
229
- if st.button("Gửi phản hồi và tải ảnh"):
230
- if correct_color and st.session_state.uploaded_image:
231
- # Save the image temporarily
232
- temp_image_path = f"temp_image_{hash(uploaded_file.name)}.png"
233
- st.session_state.uploaded_image.save(temp_image_path)
234
-
235
- # Upload to Mega.nz
236
- upload_result = upload_to_mega(temp_image_path, correct_color)
237
- if "Upload thành công" in upload_result:
238
- st.success(upload_result)
239
- else:
240
- st.error(upload_result)
241
-
242
- # Clean up temporary file
243
- os.remove(temp_image_path)
244
- else:
245
- st.warning("Vui lòng nhập màu sắc chính xác")
246
-
247
  st.markdown("</div>", unsafe_allow_html=True)
248
  else:
249
  st.info("Upload an image to see the predictions")
250
-
251
  st.markdown("---")
252
  st.markdown("Made with ❤️ using Streamlit")
253
 
 
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
6
+ from tensorflow.keras import layers, models
7
+ from tensorflow.keras.applications import EfficientNetB0
8
+ from tensorflow.keras.applications.efficientnet import preprocess_input
9
+ import joblib
10
  import io
11
 
12
  # Set page config
 
16
  layout="wide"
17
  )
18
 
19
+ # Custom CSS with improved styling
20
  st.markdown("""
21
  <style>
22
  .main {
 
26
  width: 100%;
27
  margin-top: 1rem;
28
  }
 
 
 
 
29
  .prediction-card {
30
  padding: 2rem;
31
  border-radius: 0.5rem;
32
+ background-color: #d7d7d9;
33
  margin: 1rem 0;
34
  }
35
  .top-predictions {
 
39
  border-radius: 0.5rem;
40
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
  }
 
 
 
 
 
 
 
 
 
42
  </style>
43
  """, unsafe_allow_html=True)
44
 
45
+ # Cache the model loading
46
  @st.cache_resource
47
+ def load_model_and_scaler():
48
+ """Load the trained model and scaler"""
49
+ try:
50
+ model = tf.keras.models.load_model('mlp_model.h5')
51
+ # Tải scaler đã lưu
52
+ scaler = joblib.load('standard_scaler.pkl')
53
+ return model, scaler
54
+ except Exception as e:
55
+ st.error(f"Error loading model or scaler: {str(e)}")
56
+ return None, None
57
 
58
+ def color_histogram(image, bins=16):
59
+ """Calculate color histogram features"""
60
+ hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
61
+ hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
62
+ hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
63
 
64
+ hist_r = hist_r / (np.sum(hist_r) + 1e-7)
65
+ hist_g = hist_g / (np.sum(hist_g) + 1e-7)
66
+ hist_b = hist_b / (np.sum(hist_b) + 1e-7)
67
 
68
+ return np.concatenate([hist_r, hist_g, hist_b])
69
+
70
+ def color_moments(image):
71
+ """Calculate color moments features"""
72
+ img = image.astype(np.float32) / 255.0
73
+ moments = []
74
 
75
+ for i in range(3):
76
+ channel = img[:,:,i]
77
+ mean = np.mean(channel)
78
+ std = np.std(channel) + 1e-7
79
+ skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
80
+ moments.extend([mean, std, skewness])
81
 
82
+ return np.array(moments)
83
+
84
+ def dominant_color_descriptor(image, k=3):
85
+ """Calculate dominant color descriptor"""
86
+ pixels = image.reshape(-1, 3).astype(np.float32)
87
 
88
+ criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
89
+ flags = cv2.KMEANS_RANDOM_CENTERS
90
 
91
+ try:
92
+ _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
93
+ unique, counts = np.unique(labels, return_counts=True)
94
+ percentages = counts / len(labels)
95
+ return np.concatenate([centers.flatten(), percentages])
96
+ except Exception:
97
+ return np.zeros(k * 4)
98
+
99
+ def color_coherence_vector(image, k=3):
100
+ """Calculate color coherence vector"""
101
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
102
+ gray = np.uint8(gray)
103
 
104
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
105
+ num_labels, labels = cv2.connectedComponents(binary)
106
 
107
+ ccv = []
108
+ for i in range(1, min(k+1, num_labels)):
109
+ region_mask = (labels == i)
110
+ total_pixels = np.sum(region_mask)
111
+ ccv.extend([total_pixels, total_pixels])
112
+
113
+ ccv.extend([0] * (2 * k - len(ccv)))
114
+ return np.array(ccv[:2*k])
 
 
 
 
 
115
 
116
+ @st.cache_resource
117
+ def create_vit_feature_extractor():
118
+ """Create and cache the ViT feature extractor"""
119
+ input_shape = (256, 256, 3)
120
+ inputs = layers.Input(shape=input_shape)
121
+ x = layers.Lambda(preprocess_input)(inputs)
122
+
123
+ base_model = EfficientNetB0(
124
+ include_top=False,
125
+ weights='imagenet',
126
+ input_tensor=x
127
+ )
128
+
129
+ x = layers.GlobalAveragePooling2D()(base_model.output)
130
+ return models.Model(inputs=inputs, outputs=x)
131
 
132
+ def extract_features(image):
133
+ """Extract all features from an image"""
134
+ # Traditional features
135
+ hist_features = color_histogram(image)
136
+ moment_features = color_moments(image)
137
+ dominant_features = dominant_color_descriptor(image)
138
+ ccv_features = color_coherence_vector(image)
139
+
140
+ traditional_features = np.concatenate([
141
+ hist_features,
142
+ moment_features,
143
+ dominant_features,
144
+ ccv_features
145
+ ])
146
+
147
+ # Deep features using ViT
148
+ feature_extractor = create_vit_feature_extractor()
149
+ vit_features = feature_extractor.predict(
150
+ np.expand_dims(image, axis=0),
151
+ verbose=0
152
+ )
153
+
154
+ # Combine all features
155
+ return np.concatenate([traditional_features, vit_features.flatten()])
156
 
157
+ def preprocess_image(image, scaler):
158
+ """Preprocess the uploaded image"""
159
+ # Convert to RGB if needed
160
+ if image.mode != 'RGB':
161
+ image = image.convert('RGB')
162
+
163
+ # Convert to numpy array and resize
164
+ img_array = np.array(image)
165
+ img_array = cv2.resize(img_array, (256, 256))
166
+ img_array = img_array.astype('float32') / 255.0
167
+
168
+ # Extract all features
169
+ features = extract_features(img_array)
170
+
171
+ # Scale features using the provided scaler
172
+ scaled_features = scaler.transform(features.reshape(1, -1))
173
+
174
+ return scaled_features
175
 
 
 
 
176
  def get_top_predictions(prediction, class_names, top_k=5):
177
  """Get top k predictions with their probabilities"""
 
178
  top_indices = prediction.argsort()[0][-top_k:][::-1]
179
+ return [
 
 
180
  (class_names[i], float(prediction[0][i]) * 100)
181
  for i in top_indices
182
  ]
 
 
183
 
184
  def main():
 
185
  st.title("🪨 Stone Classification")
186
  st.write("Upload an image of a stone to classify its type")
187
 
188
+ # Load model and scaler
189
+ model, scaler = load_model_and_scaler()
190
+ if model is None or scaler is None:
191
+ st.error("Failed to load model or scaler. Please ensure both files exist.")
192
+ return
193
+
194
+ # Initialize session state
195
  if 'predictions' not in st.session_state:
196
  st.session_state.predictions = None
197
 
 
198
  col1, col2 = st.columns(2)
199
 
200
  with col1:
 
202
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
203
 
204
  if uploaded_file is not None:
205
+ try:
206
+ image = Image.open(uploaded_file)
207
+ st.image(image, caption="Uploaded Image", use_column_width=True)
208
+
209
+ with st.spinner('Analyzing image...'):
210
+ processed_image = preprocess_image(image, scaler)
211
+ prediction = model.predict(processed_image, verbose=0)
 
 
 
 
212
 
 
 
213
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
214
+ st.session_state.predictions = get_top_predictions(prediction, class_names)
215
 
216
+ except Exception as e:
217
+ st.error(f"Error processing image: {str(e)}")
 
 
 
 
 
 
218
 
219
  with col2:
220
  st.subheader("Prediction Results")
 
230
  """,
231
  unsafe_allow_html=True
232
  )
233
+
234
  # Display confidence bar
235
  st.progress(top_confidence / 100)
236
+
237
  # Display top 5 predictions
238
  st.markdown("### Top 5 Predictions")
239
  st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
240
+
241
  for class_name, confidence in st.session_state.predictions:
242
  cols = st.columns([2, 6, 2])
243
  with cols[0]:
 
246
  st.progress(confidence / 100)
247
  with cols[2]:
248
  st.write(f"{confidence:.2f}%")
249
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  st.markdown("</div>", unsafe_allow_html=True)
251
  else:
252
  st.info("Upload an image to see the predictions")
253
+
254
  st.markdown("---")
255
  st.markdown("Made with ❤️ using Streamlit")
256