SonFox2920 committed (verified)
Commit af547ff · 1 Parent(s): ad617e9

Update app.py

Files changed (1)
  1. app.py +125 -25
app.py CHANGED
@@ -1,3 +1,4 @@
+%%writefile app.py
 import streamlit as st
 import tensorflow as tf
 import numpy as np
@@ -5,7 +6,19 @@ import cv2
 from PIL import Image
 import io
 import torch
-
+import cloudinary
+import cloudinary.uploader
+from cloudinary.utils import cloudinary_url
+import os
+import random
+import string
+# Cloudinary Configuration
+cloudinary.config(
+    cloud_name = os.getenv("CLOUD"),
+    api_key = os.getenv("API"),
+    api_secret = os.getenv("SECRET"),
+    secure=True
+)
 # Set page config
 st.set_page_config(
     page_title="Stone Detection & Classification",
@@ -13,6 +26,10 @@ st.set_page_config(
     layout="wide"
 )

+def generate_random_filename(extension="png"):
+    random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
+    return f"temp_image_{random_string}.{extension}"
+
 # Custom CSS to improve the appearance
 st.markdown("""
     <style>
@@ -29,20 +46,55 @@ st.markdown("""
29
  }
30
  </style>
31
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def resize_to_square(image):
34
  """Resize image to square while maintaining aspect ratio"""
35
  size = max(image.shape[0], image.shape[1])
36
  new_img = np.zeros((size, size, 3), dtype=np.uint8)
37
-
38
  # Calculate position to paste original image
39
  x_center = (size - image.shape[1]) // 2
40
  y_center = (size - image.shape[0]) // 2
41
-
42
  # Copy the image into center of result image
43
- new_img[y_center:y_center+image.shape[0],
44
  x_center:x_center+image.shape[1]] = image
45
-
46
  return new_img
47
 
48
  @st.cache_resource
@@ -53,17 +105,17 @@ def load_models():
     object_detection_model = torch.load("fasterrcnn_resnet50_fpn_090824.pth", map_location=device)
     object_detection_model.to(device)
     object_detection_model.eval()
-
+
     # Load classification model
     classification_model = tf.keras.models.load_model('custom_model.h5')
-
+
     return object_detection_model, classification_model, device

 def perform_object_detection(image, model, device):
     original_size = image.size
     target_size = (256, 256)
     frame_resized = cv2.resize(np.array(image), dsize=target_size, interpolation=cv2.INTER_AREA)
-    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
+    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
     frame_rgb /= 255.0
     frame_rgb = frame_rgb.transpose(2, 0, 1)
     frame_rgb = torch.from_numpy(frame_rgb).float().unsqueeze(0).to(device)
@@ -91,16 +143,22 @@ def perform_object_detection(image, model, device):
         scale_h, scale_w = original_h / target_size[0], original_w / target_size[1]
         x1_orig, y1_orig = int(x1 * scale_w), int(y1 * scale_h)
         x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
-
+
         # Crop and process detected region
         cropped_image = np.array(image)[y1_orig:y2_orig, x1_orig:x2_orig]
+
+        # Check if image has 4 channels (RGBA), convert to RGB
+        if cropped_image.shape[-1] == 4:
+            cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_RGBA2RGB)
+
+        # Resize cropped image
         resized_crop = resize_to_square(cropped_image)
         cropped_images.append(resized_crop)
         detected_boxes.append((x1, y1, x2, y2))

         # Draw bounding box
         cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
-        cv2.putText(result_image, label_text, (x1, y1 - 10),
+        cv2.putText(result_image, label_text, (x1, y1 - 10),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

     return Image.fromarray(result_image), cropped_images, detected_boxes
@@ -124,16 +182,16 @@ def get_top_predictions(prediction, class_names, top_k=5):
 def main():
     st.title("🪨 Stone Detection & Classification")
     st.write("Upload an image to detect and classify stone surfaces")
-
+
     if 'predictions' not in st.session_state:
         st.session_state.predictions = None
-
+
     col1, col2 = st.columns(2)
-
+
     with col1:
         st.subheader("Upload Image")
         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-
+
         if uploaded_file is not None:
             image = Image.open(uploaded_file)
             st.image(image, caption="Uploaded Image", use_column_width=True)
@@ -142,24 +200,24 @@ def main():
                 try:
                     # Load both models
                     object_detection_model, classification_model, device = load_models()
-
+
                     # Perform object detection
                     result_image, cropped_images, detected_boxes = perform_object_detection(
                         image, object_detection_model, device
                     )
-
+
                     if not cropped_images:
                         st.warning("No stone surfaces detected in the image")
                         return
-
+
                     # Display detection results
                     st.subheader("Detection Results")
                     st.image(result_image, caption="Detected Stone Surfaces", use_column_width=True)
-
+
                     # Process each detected region
                     class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
                     all_predictions = []
-
+
                     for idx, cropped_image in enumerate(cropped_images):
                         processed_image = preprocess_image(cropped_image)
                         prediction = classification_model.predict(
@@ -167,25 +225,25 @@
                         )
                         top_predictions = get_top_predictions(prediction, class_names)
                         all_predictions.append(top_predictions)
-
+
                     # Store in session state
                     st.session_state.predictions = all_predictions
-
+
                 except Exception as e:
                     st.error(f"Error during processing: {str(e)}")
-
+
     with col2:
         st.subheader("Classification Results")
         if st.session_state.predictions is not None:
             for idx, predictions in enumerate(st.session_state.predictions):
                 st.markdown(f"### Region {idx + 1}")
-
+
                 # Display main prediction
                 top_class, top_confidence = predictions[0]
                 st.markdown(f"**Primary Prediction: Grade {top_class}**")
                 st.markdown(f"**Confidence: {top_confidence:.2f}%**")
                 st.progress(top_confidence / 100)
-
+
                 # Display all predictions for this region
                 st.markdown("**Top 5 Predictions**")
                 for class_name, confidence in predictions:
@@ -196,8 +254,50 @@
                         st.progress(confidence / 100)
                     with col_value:
                         st.write(f"{confidence:.2f}%")
-
+
                 st.markdown("---")
+                st.markdown("</div>", unsafe_allow_html=True)
+
+                # User Confirmation Section
+                st.markdown("### Xác nhận độ chính xác của mô hình")
+                st.write("Giúp chúng tôi cải thiện mô hình bằng cách xác nhận độ chính xác của dự đoán.")
+
+                # Accuracy Radio Button
+                accuracy_option = st.radio(
+                    "Dự đoán có chính xác không?",
+                    ["Chọn", "Chính xác", "Không chính xác"],
+                    index=0,
+                    key=f"accuracy_radio_{idx}"
+                )
+                if accuracy_option == "Không chính xác":
+                    # Input for correct grade
+                    correct_grade = st.selectbox(
+                        "Chọn màu đá đúng:",
+                        ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'],
+                        index=None,
+                        placeholder="Chọn màu đúng",
+                        key=f"selectbox_correct_grade_{idx}"
+                    )
+
+                    # Only proceed once the user has picked a value in the selectbox
+                    if correct_grade:
+                        st.info(f"Đã chọn màu đúng: {correct_grade}")
+
+                        # Resize the image down to 256x256
+                        resized_image = Image.fromarray(cropped_image).resize((256, 256))
+                        temp_image_path = generate_random_filename()
+
+                        # Save the resized image to a temporary file
+                        resized_image.save(temp_image_path)
+
+                        # Upload the image to Cloudinary
+                        cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
+
+                        if isinstance(cloudinary_result, dict):
+                            st.success(f"Hình ảnh đã được tải lên thành công cho màu {correct_grade}")
+                            st.write(f"URL công khai: {cloudinary_result['upload_result']['secure_url']}")
+                        else:
+                            st.error(cloudinary_result)
         else:
             st.info("Upload an image to see detection and classification results")
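Note: the new cloudinary.config(...) block reads its credentials from the CLOUD, API and SECRET environment variables (on a Hugging Face Space these would normally be supplied as repository secrets). A minimal startup check along those lines is sketched below; the require_env helper is illustrative only and is not part of this commit.

    # Hypothetical startup check, not part of this commit: fail fast with a clear
    # message when the Cloudinary credentials expected by app.py are missing.
    import os

    def require_env(*names: str) -> dict:
        """Return the requested environment variables, raising if any are unset."""
        missing = [n for n in names if not os.getenv(n)]
        if missing:
            raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
        return {n: os.environ[n] for n in names}

    if __name__ == "__main__":
        creds = require_env("CLOUD", "API", "SECRET")  # the names used by app.py
        print("Cloudinary credentials found for cloud:", creds["CLOUD"])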
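Note: the confirmation flow added at the end of main() saves the selected crop to a temporary file, uploads it to a Cloudinary folder named after the corrected grade, and reports the resulting URL; app.py itself never deletes the temporary file, and the %%writefile app.py line added at the top of the file is an IPython cell magic, so it would have to be stripped before the file can run under streamlit run. Below is a standalone sketch of the same upload path, assuming CLOUD/API/SECRET are set; the synthetic crop and the "9.5" label are placeholders.

    # Standalone sketch mirroring the upload path added in this commit: save a
    # 256x256 crop to a temporary file, upload it to a folder named after the
    # corrected grade, then remove the temporary file.
    import os
    import random
    import string

    import cloudinary
    import cloudinary.uploader
    import numpy as np
    from PIL import Image

    cloudinary.config(
        cloud_name=os.getenv("CLOUD"),
        api_key=os.getenv("API"),
        api_secret=os.getenv("SECRET"),
        secure=True,
    )

    def generate_random_filename(extension="png"):
        # Same naming scheme as the helper added to app.py.
        random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
        return f"temp_image_{random_string}.{extension}"

    # Synthetic stand-in for a detected stone crop.
    crop = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
    label = "9.5"  # the grade chosen by the user in the confirmation UI

    temp_path = generate_random_filename()
    Image.fromarray(crop).resize((256, 256)).save(temp_path)

    try:
        result = cloudinary.uploader.upload(
            temp_path,
            folder=label,
            public_id=f"{label}_{os.path.basename(temp_path)}",
        )
        print("Uploaded:", result["secure_url"])
    finally:
        os.remove(temp_path)  # app.py leaves the temporary file in place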
 
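Note: the RGBA check added in perform_object_detection converts 4-channel crops (for example from PNG uploads with transparency) to RGB before they are resized, so downstream code can assume 3 channels. A small self-contained illustration of that guard, using a synthetic array rather than an uploaded image:

    # Illustration of the RGBA guard added in perform_object_detection.
    import cv2
    import numpy as np

    crop = np.random.randint(0, 255, size=(64, 64, 4), dtype=np.uint8)  # synthetic RGBA crop

    if crop.shape[-1] == 4:
        crop = cv2.cvtColor(crop, cv2.COLOR_RGBA2RGB)

    print(crop.shape)  # (64, 64, 3)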