SonFox2920 commited on
Commit
0623da1
·
verified ·
1 Parent(s): 36422a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -150
app.py CHANGED
@@ -8,6 +8,11 @@ from tensorflow.keras.applications import EfficientNetB0
8
  from tensorflow.keras.applications.efficientnet import preprocess_input
9
  import joblib
10
  import io
 
 
 
 
 
11
 
12
  # Set page config
13
  st.set_page_config(
@@ -39,183 +44,80 @@ st.markdown("""
39
  border-radius: 0.5rem;
40
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
  }
 
 
 
 
 
 
42
  </style>
43
  """, unsafe_allow_html=True)
44
 
45
- # Cache the model loading
46
- @st.cache_resource
47
- def load_model_and_scaler():
48
- """Load the trained model and scaler"""
 
 
 
49
  try:
50
- model = tf.keras.models.load_model('mlp_model.h5')
51
- # Tải scaler đã lưu
52
- scaler = joblib.load('standard_scaler.pkl')
53
- return model, scaler
 
 
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
- st.error(f"Error loading model or scaler: {str(e)}")
56
- return None, None
57
-
58
- def color_histogram(image, bins=16):
59
- """Calculate color histogram features"""
60
- hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
61
- hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
62
- hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
63
-
64
- hist_r = hist_r / (np.sum(hist_r) + 1e-7)
65
- hist_g = hist_g / (np.sum(hist_g) + 1e-7)
66
- hist_b = hist_b / (np.sum(hist_b) + 1e-7)
67
-
68
- return np.concatenate([hist_r, hist_g, hist_b])
69
-
70
- def color_moments(image):
71
- """Calculate color moments features"""
72
- img = image.astype(np.float32) / 255.0
73
- moments = []
74
-
75
- for i in range(3):
76
- channel = img[:,:,i]
77
- mean = np.mean(channel)
78
- std = np.std(channel) + 1e-7
79
- skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
80
- moments.extend([mean, std, skewness])
81
-
82
- return np.array(moments)
83
-
84
- def dominant_color_descriptor(image, k=3):
85
- """Calculate dominant color descriptor"""
86
- pixels = image.reshape(-1, 3).astype(np.float32)
87
-
88
- criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
89
- flags = cv2.KMEANS_RANDOM_CENTERS
90
-
91
- try:
92
- _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
93
- unique, counts = np.unique(labels, return_counts=True)
94
- percentages = counts / len(labels)
95
- return np.concatenate([centers.flatten(), percentages])
96
- except Exception:
97
- return np.zeros(k * 4)
98
-
99
- def color_coherence_vector(image, k=3):
100
- """Calculate color coherence vector"""
101
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
102
- gray = np.uint8(gray)
103
-
104
- _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
105
- num_labels, labels = cv2.connectedComponents(binary)
106
-
107
- ccv = []
108
- for i in range(1, min(k+1, num_labels)):
109
- region_mask = (labels == i)
110
- total_pixels = np.sum(region_mask)
111
- ccv.extend([total_pixels, total_pixels])
112
-
113
- ccv.extend([0] * (2 * k - len(ccv)))
114
- return np.array(ccv[:2*k])
115
-
116
- @st.cache_resource
117
- def create_vit_feature_extractor():
118
- """Create and cache the ViT feature extractor"""
119
- input_shape = (256, 256, 3)
120
- inputs = layers.Input(shape=input_shape)
121
- x = layers.Lambda(preprocess_input)(inputs)
122
-
123
- base_model = EfficientNetB0(
124
- include_top=False,
125
- weights='imagenet',
126
- input_tensor=x
127
- )
128
-
129
- x = layers.GlobalAveragePooling2D()(base_model.output)
130
- return models.Model(inputs=inputs, outputs=x)
131
-
132
- def extract_features(image):
133
- """Extract all features from an image"""
134
- # Traditional features
135
- hist_features = color_histogram(image)
136
- moment_features = color_moments(image)
137
- dominant_features = dominant_color_descriptor(image)
138
- ccv_features = color_coherence_vector(image)
139
-
140
- traditional_features = np.concatenate([
141
- hist_features,
142
- moment_features,
143
- dominant_features,
144
- ccv_features
145
- ])
146
-
147
- # Deep features using ViT
148
- feature_extractor = create_vit_feature_extractor()
149
- vit_features = feature_extractor.predict(
150
- np.expand_dims(image, axis=0),
151
- verbose=0
152
- )
153
-
154
- # Combine all features
155
- return np.concatenate([traditional_features, vit_features.flatten()])
156
-
157
- def preprocess_image(image, scaler):
158
- """Preprocess the uploaded image"""
159
- # Convert to RGB if needed
160
- if image.mode != 'RGB':
161
- image = image.convert('RGB')
162
-
163
- # Convert to numpy array and resize
164
- img_array = np.array(image)
165
- img_array = cv2.resize(img_array, (256, 256))
166
- img_array = img_array.astype('float32') / 255.0
167
-
168
- # Extract all features
169
- features = extract_features(img_array)
170
-
171
- # Scale features using the provided scaler
172
- scaled_features = scaler.transform(features.reshape(1, -1))
173
-
174
- return scaled_features
175
-
176
- def get_top_predictions(prediction, class_names, top_k=5):
177
- """Get top k predictions with their probabilities"""
178
- top_indices = prediction.argsort()[0][-top_k:][::-1]
179
- return [
180
- (class_names[i], float(prediction[0][i]) * 100)
181
- for i in top_indices
182
- ]
183
 
184
  def main():
185
  st.title("🪨 Stone Classification")
186
  st.write("Upload an image of a stone to classify its type")
187
-
188
  # Load model and scaler
189
  model, scaler = load_model_and_scaler()
190
  if model is None or scaler is None:
191
  st.error("Failed to load model or scaler. Please ensure both files exist.")
192
  return
193
-
194
  # Initialize session state
195
  if 'predictions' not in st.session_state:
196
  st.session_state.predictions = None
197
-
 
 
198
  col1, col2 = st.columns(2)
199
-
200
  with col1:
201
  st.subheader("Upload Image")
202
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
203
-
204
  if uploaded_file is not None:
205
  try:
206
  image = Image.open(uploaded_file)
207
  st.image(image, caption="Uploaded Image", use_column_width=True)
208
-
 
209
  with st.spinner('Analyzing image...'):
210
  processed_image = preprocess_image(image, scaler)
211
  prediction = model.predict(processed_image, verbose=0)
212
-
213
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
214
  st.session_state.predictions = get_top_predictions(prediction, class_names)
215
-
216
  except Exception as e:
217
  st.error(f"Error processing image: {str(e)}")
218
-
219
  with col2:
220
  st.subheader("Prediction Results")
221
  if st.session_state.predictions:
@@ -230,14 +132,14 @@ def main():
230
  """,
231
  unsafe_allow_html=True
232
  )
233
-
234
  # Display confidence bar
235
  st.progress(top_confidence / 100)
236
-
237
  # Display top 5 predictions
238
  st.markdown("### Top 5 Predictions")
239
  st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
240
-
241
  for class_name, confidence in st.session_state.predictions:
242
  cols = st.columns([2, 6, 2])
243
  with cols[0]:
@@ -246,11 +148,50 @@ def main():
246
  st.progress(confidence / 100)
247
  with cols[2]:
248
  st.write(f"{confidence:.2f}%")
249
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  st.markdown("</div>", unsafe_allow_html=True)
251
  else:
252
  st.info("Upload an image to see the predictions")
253
-
254
  st.markdown("---")
255
  st.markdown("Made with ❤️ using Streamlit")
256
 
 
8
  from tensorflow.keras.applications.efficientnet import preprocess_input
9
  import joblib
10
  import io
11
+ import os
12
+ from google.oauth2.credentials import Credentials
13
+ from googleapiclient.discovery import build
14
+ from googleapiclient.http import MediaFileUpload
15
+ from google.oauth2 import service_account
16
 
17
  # Set page config
18
  st.set_page_config(
 
44
  border-radius: 0.5rem;
45
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
46
  }
47
+ .survey-card {
48
+ padding: 1rem;
49
+ background-color: #f0f2f6;
50
+ border-radius: 0.5rem;
51
+ margin-top: 1rem;
52
+ }
53
  </style>
54
  """, unsafe_allow_html=True)
55
 
56
+ from mega import Mega
57
+
58
+ # Đăng nhập vào tài khoản Mega
59
+ def upload_to_mega(file_path, folder_name):
60
+ """
61
+ Upload file to a specific folder on Mega.nz
62
+ """
63
  try:
64
+ # Đăng nhập vào tài khoản Mega
65
+ mega = Mega()
66
+ m = mega.login(os.getenv('EMAIL'), os.getenv('PASSWORD'))
67
+
68
+ # Tìm thư mục đích
69
+ folder = m.find(folder_name)
70
+
71
+ if not folder:
72
+ # Nếu thư mục không tồn tại, hiển thị thông báo lỗi
73
+ return f"Thư mục '{folder_name}' không tồn tại!"
74
+
75
+ # Tải tệp lên thư mục
76
+ file = m.upload(file_path, folder[0])
77
+ return f"Upload thành công! Link: {m.get_upload_link(file)}"
78
+
79
  except Exception as e:
80
+ return f"Lỗi khi tải lên Mega: {str(e)}"
81
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def main():
84
  st.title("🪨 Stone Classification")
85
  st.write("Upload an image of a stone to classify its type")
86
+
87
  # Load model and scaler
88
  model, scaler = load_model_and_scaler()
89
  if model is None or scaler is None:
90
  st.error("Failed to load model or scaler. Please ensure both files exist.")
91
  return
92
+
93
  # Initialize session state
94
  if 'predictions' not in st.session_state:
95
  st.session_state.predictions = None
96
+ if 'uploaded_image' not in st.session_state:
97
+ st.session_state.uploaded_image = None
98
+
99
  col1, col2 = st.columns(2)
100
+
101
  with col1:
102
  st.subheader("Upload Image")
103
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
104
+
105
  if uploaded_file is not None:
106
  try:
107
  image = Image.open(uploaded_file)
108
  st.image(image, caption="Uploaded Image", use_column_width=True)
109
+ st.session_state.uploaded_image = image
110
+
111
  with st.spinner('Analyzing image...'):
112
  processed_image = preprocess_image(image, scaler)
113
  prediction = model.predict(processed_image, verbose=0)
114
+
115
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
116
  st.session_state.predictions = get_top_predictions(prediction, class_names)
117
+
118
  except Exception as e:
119
  st.error(f"Error processing image: {str(e)}")
120
+
121
  with col2:
122
  st.subheader("Prediction Results")
123
  if st.session_state.predictions:
 
132
  """,
133
  unsafe_allow_html=True
134
  )
135
+
136
  # Display confidence bar
137
  st.progress(top_confidence / 100)
138
+
139
  # Display top 5 predictions
140
  st.markdown("### Top 5 Predictions")
141
  st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
142
+
143
  for class_name, confidence in st.session_state.predictions:
144
  cols = st.columns([2, 6, 2])
145
  with cols[0]:
 
148
  st.progress(confidence / 100)
149
  with cols[2]:
150
  st.write(f"{confidence:.2f}%")
151
+
152
+ st.markdown("</div>", unsafe_allow_html=True)
153
+
154
+ # User Survey
155
+ st.markdown("<div class='survey-card'>", unsafe_allow_html=True)
156
+ st.markdown("### Model Accuracy Survey")
157
+ st.write("Mô hình có dự đoán chính xác màu sắc của đá trong ảnh này không?")
158
+
159
+ # Accuracy Confirmation
160
+ accuracy = st.radio(
161
+ "Đánh giá độ chính xác",
162
+ ["Chọn", "Chính xác", "Không chính xác"],
163
+ index=0
164
+ )
165
+
166
+ if accuracy == "Không chính xác":
167
+ # Color input for incorrect prediction
168
+ correct_color = st.text_input(
169
+ "Vui lòng nhập màu sắc chính xác của đá:",
170
+ help="Ví dụ: 10, 9.7, 9.5, 9.2, v.v."
171
+ )
172
+
173
+ if st.button("Gửi phản hồi và tải ảnh"):
174
+ if correct_color and st.session_state.uploaded_image:
175
+ # Save the image temporarily
176
+ temp_image_path = f"temp_image_{hash(uploaded_file.name)}.png"
177
+ st.session_state.uploaded_image.save(temp_image_path)
178
+
179
+ # Upload to Mega.nz
180
+ upload_result = upload_to_mega(temp_image_path, correct_color)
181
+ if "Upload thành công" in upload_result:
182
+ st.success(upload_result)
183
+ else:
184
+ st.error(upload_result)
185
+
186
+ # Clean up temporary file
187
+ os.remove(temp_image_path)
188
+ else:
189
+ st.warning("Vui lòng nhập màu sắc chính xác")
190
+
191
  st.markdown("</div>", unsafe_allow_html=True)
192
  else:
193
  st.info("Upload an image to see the predictions")
194
+
195
  st.markdown("---")
196
  st.markdown("Made with ❤️ using Streamlit")
197