SonFox2920 commited on
Commit
2cc5732
·
verified ·
1 Parent(s): 83eb891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -99
app.py CHANGED
@@ -4,10 +4,11 @@ import numpy as np
4
  import cv2
5
  from PIL import Image
6
  import io
 
7
 
8
  # Set page config
9
  st.set_page_config(
10
- page_title="Stone Classification",
11
  page_icon="🪨",
12
  layout="wide"
13
  )
@@ -26,96 +27,107 @@ st.markdown("""
26
  text-align: center;
27
  padding: 2rem;
28
  }
29
-
30
  </style>
31
  """, unsafe_allow_html=True)
32
- # .prediction-card {
33
- # padding: 2rem;
34
- # border-radius: 0.5rem;
35
- # background-color: #f0f2f6;
36
- # margin: 1rem 0;
37
- # }
38
- # .top-predictions {
39
- # margin-top: 2rem;
40
- # padding: 1rem;
41
- # background-color: white;
42
- # border-radius: 0.5rem;
43
- # box-shadow: 0 1px 3px rgba(0,0,0,0.12);
44
- # }
45
- # .prediction-bar {
46
- # display: flex;
47
- # align-items: center;
48
- # margin: 0.5rem 0;
49
- # }
50
- # .prediction-label {
51
- # width: 100px;
52
- # font-weight: 500;
53
- # }
54
- @st.cache_resource
55
- def load_model():
56
- """Load the trained model"""
57
- return tf.keras.models.load_model('custom_model.h5')
58
 
59
- def preprocess_image(image):
60
- """Preprocess the uploaded image"""
61
- # # Convert to RGB if needed
62
- # if image.mode != 'RGB':
63
- # image = image.convert('RGB')
64
-
65
- # Convert to numpy array
66
- img_array = np.array(image)
67
 
68
- # # Convert to RGB if needed
69
- # if len(img_array.shape) == 2: # Grayscale
70
- # img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
71
- # elif img_array.shape[2] == 4: # RGBA
72
- # img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
73
 
74
- # # Preprocess image similar to training
75
- # img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
76
- # img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
77
- # img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
78
 
79
- # # Adjust brightness
80
- # target_brightness = 150
81
- # current_brightness = np.mean(img_array)
82
- # alpha = target_brightness / (current_brightness + 1e-5)
83
- # img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
 
 
 
 
 
84
 
85
- # # Apply Gaussian blur
86
- # img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
87
 
88
- # Resize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  img_array = cv2.resize(img_array, (256, 256))
90
-
91
- # Normalize
92
  img_array = img_array.astype('float32') / 255.0
93
-
94
  return img_array
95
 
96
  def get_top_predictions(prediction, class_names, top_k=5):
97
  """Get top k predictions with their probabilities"""
98
- # Get indices of top k predictions
99
  top_indices = prediction.argsort()[0][-top_k:][::-1]
100
-
101
- # Get corresponding class names and probabilities
102
  top_predictions = [
103
  (class_names[i], float(prediction[0][i]) * 100)
104
  for i in top_indices
105
  ]
106
-
107
  return top_predictions
108
 
109
  def main():
110
- # Title
111
- st.title("🪨 Stone Classification")
112
- st.write("Upload an image of a stone to classify its type")
113
 
114
- # Initialize session state for prediction if not exists
115
  if 'predictions' not in st.session_state:
116
  st.session_state.predictions = None
117
 
118
- # Create two columns
119
  col1, col2 = st.columns(2)
120
 
121
  with col1:
@@ -123,53 +135,60 @@ def main():
123
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
124
 
125
  if uploaded_file is not None:
126
- # Display uploaded image
127
  image = Image.open(uploaded_file)
128
  st.image(image, caption="Uploaded Image", use_column_width=True)
129
 
130
- with st.spinner('Analyzing image...'):
131
  try:
132
- # Load model
133
- model = load_model()
134
 
135
- # Preprocess image
136
- processed_image = preprocess_image(image)
 
 
137
 
138
- # Make prediction
139
- prediction = model.predict(np.expand_dims(processed_image, axis=0))
 
 
 
 
 
 
 
140
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
 
141
 
142
- # Get top 5 predictions
143
- top_predictions = get_top_predictions(prediction, class_names)
 
 
 
 
 
144
 
145
  # Store in session state
146
- st.session_state.predictions = top_predictions
147
 
148
  except Exception as e:
149
- st.error(f"Error during prediction: {str(e)}")
150
 
151
  with col2:
152
- st.subheader("Prediction Results")
153
  if st.session_state.predictions is not None:
154
- # Create a card-like container for results
155
- results_container = st.container()
156
- with results_container:
157
- # Display main prediction
158
- st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
159
- top_class, top_confidence = st.session_state.predictions[0]
160
- st.markdown(f"### Primary Prediction: Grade {top_class}")
161
- st.markdown(f"### Confidence: {top_confidence:.2f}%")
162
- st.markdown("</div>", unsafe_allow_html=True)
163
 
164
- # Display confidence bar for top prediction
 
 
 
165
  st.progress(top_confidence / 100)
166
 
167
- # Display top 5 predictions
168
- st.markdown("### Top 5 Predictions")
169
- st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
170
-
171
- # Create a Streamlit container for the predictions
172
- for class_name, confidence in st.session_state.predictions:
173
  col_label, col_bar, col_value = st.columns([2, 6, 2])
174
  with col_label:
175
  st.write(f"Grade {class_name}")
@@ -178,11 +197,9 @@ def main():
178
  with col_value:
179
  st.write(f"{confidence:.2f}%")
180
 
181
- st.markdown("</div>", unsafe_allow_html=True)
182
  else:
183
- st.info("Upload an image and click 'Predict' to see the results")
184
-
185
- # Footer
186
- st.markdown("---")
187
  if __name__ == "__main__":
188
  main()
 
4
  import cv2
5
  from PIL import Image
6
  import io
7
+ import torch
8
 
9
  # Set page config
10
  st.set_page_config(
11
+ page_title="Stone Detection & Classification",
12
  page_icon="🪨",
13
  layout="wide"
14
  )
 
27
  text-align: center;
28
  padding: 2rem;
29
  }
 
30
  </style>
31
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def resize_to_square(image):
34
+ """Resize image to square while maintaining aspect ratio"""
35
+ size = max(image.shape[0], image.shape[1])
36
+ new_img = np.zeros((size, size, 3), dtype=np.uint8)
 
 
 
 
37
 
38
+ # Calculate position to paste original image
39
+ x_center = (size - image.shape[1]) // 2
40
+ y_center = (size - image.shape[0]) // 2
 
 
41
 
42
+ # Copy the image into center of result image
43
+ new_img[y_center:y_center+image.shape[0],
44
+ x_center:x_center+image.shape[1]] = image
 
45
 
46
+ return new_img
47
+
48
+ @st.cache_resource
49
+ def load_models():
50
+ """Load both object detection and classification models"""
51
+ # Load object detection model
52
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
53
+ object_detection_model = torch.load("fasterrcnn_resnet50_fpn_270824.pth", map_location=device)
54
+ object_detection_model.to(device)
55
+ object_detection_model.eval()
56
 
57
+ # Load classification model
58
+ classification_model = tf.keras.models.load_model('custom_model.h5')
59
 
60
+ return object_detection_model, classification_model, device
61
+
62
+ def perform_object_detection(image, model, device):
63
+ original_size = image.size
64
+ target_size = (256, 256)
65
+ frame_resized = cv2.resize(np.array(image), dsize=target_size, interpolation=cv2.INTER_AREA)
66
+ frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32)
67
+ frame_rgb /= 255.0
68
+ frame_rgb = frame_rgb.transpose(2, 0, 1)
69
+ frame_rgb = torch.from_numpy(frame_rgb).float().unsqueeze(0).to(device)
70
+
71
+ with torch.no_grad():
72
+ outputs = model(frame_rgb)
73
+
74
+ boxes = outputs[0]['boxes'].cpu().detach().numpy().astype(np.int32)
75
+ labels = outputs[0]['labels'].cpu().detach().numpy().astype(np.int32)
76
+ scores = outputs[0]['scores'].cpu().detach().numpy()
77
+
78
+ result_image = frame_resized.copy()
79
+ cropped_images = []
80
+ detected_boxes = []
81
+
82
+ for i in range(len(boxes)):
83
+ if scores[i] >= 0.75:
84
+ x1, y1, x2, y2 = boxes[i]
85
+ if (int(labels[i])-1) == 1 or (int(labels[i])-1) == 0:
86
+ color = (0, 0, 255)
87
+ label_text = 'Flame stone surface'
88
+
89
+ # Scale coordinates to original image size
90
+ original_h, original_w = original_size[::-1]
91
+ scale_h, scale_w = original_h / target_size[0], original_w / target_size[1]
92
+ x1_orig, y1_orig = int(x1 * scale_w), int(y1 * scale_h)
93
+ x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
94
+
95
+ # Crop and process detected region
96
+ cropped_image = np.array(image)[y1_orig:y2_orig, x1_orig:x2_orig]
97
+ resized_crop = resize_to_square(cropped_image)
98
+ cropped_images.append(resized_crop)
99
+ detected_boxes.append((x1, y1, x2, y2))
100
+
101
+ # Draw bounding box
102
+ cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
103
+ cv2.putText(result_image, label_text, (x1, y1 - 10),
104
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
105
+
106
+ return Image.fromarray(result_image), cropped_images, detected_boxes
107
+
108
+ def preprocess_image(image):
109
+ """Preprocess the image for classification"""
110
+ img_array = np.array(image)
111
  img_array = cv2.resize(img_array, (256, 256))
 
 
112
  img_array = img_array.astype('float32') / 255.0
 
113
  return img_array
114
 
115
  def get_top_predictions(prediction, class_names, top_k=5):
116
  """Get top k predictions with their probabilities"""
 
117
  top_indices = prediction.argsort()[0][-top_k:][::-1]
 
 
118
  top_predictions = [
119
  (class_names[i], float(prediction[0][i]) * 100)
120
  for i in top_indices
121
  ]
 
122
  return top_predictions
123
 
124
  def main():
125
+ st.title("🪨 Stone Detection & Classification")
126
+ st.write("Upload an image to detect and classify stone surfaces")
 
127
 
 
128
  if 'predictions' not in st.session_state:
129
  st.session_state.predictions = None
130
 
 
131
  col1, col2 = st.columns(2)
132
 
133
  with col1:
 
135
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
136
 
137
  if uploaded_file is not None:
 
138
  image = Image.open(uploaded_file)
139
  st.image(image, caption="Uploaded Image", use_column_width=True)
140
 
141
+ with st.spinner('Processing image...'):
142
  try:
143
+ # Load both models
144
+ object_detection_model, classification_model, device = load_models()
145
 
146
+ # Perform object detection
147
+ result_image, cropped_images, detected_boxes = perform_object_detection(
148
+ image, object_detection_model, device
149
+ )
150
 
151
+ if not cropped_images:
152
+ st.warning("No stone surfaces detected in the image")
153
+ return
154
+
155
+ # Display detection results
156
+ st.subheader("Detection Results")
157
+ st.image(result_image, caption="Detected Stone Surfaces", use_column_width=True)
158
+
159
+ # Process each detected region
160
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
161
+ all_predictions = []
162
 
163
+ for idx, cropped_image in enumerate(cropped_images):
164
+ processed_image = preprocess_image(cropped_image)
165
+ prediction = classification_model.predict(
166
+ np.expand_dims(processed_image, axis=0)
167
+ )
168
+ top_predictions = get_top_predictions(prediction, class_names)
169
+ all_predictions.append(top_predictions)
170
 
171
  # Store in session state
172
+ st.session_state.predictions = all_predictions
173
 
174
  except Exception as e:
175
+ st.error(f"Error during processing: {str(e)}")
176
 
177
  with col2:
178
+ st.subheader("Classification Results")
179
  if st.session_state.predictions is not None:
180
+ for idx, predictions in enumerate(st.session_state.predictions):
181
+ st.markdown(f"### Region {idx + 1}")
 
 
 
 
 
 
 
182
 
183
+ # Display main prediction
184
+ top_class, top_confidence = predictions[0]
185
+ st.markdown(f"**Primary Prediction: Grade {top_class}**")
186
+ st.markdown(f"**Confidence: {top_confidence:.2f}%**")
187
  st.progress(top_confidence / 100)
188
 
189
+ # Display all predictions for this region
190
+ st.markdown("**Top 5 Predictions**")
191
+ for class_name, confidence in predictions:
 
 
 
192
  col_label, col_bar, col_value = st.columns([2, 6, 2])
193
  with col_label:
194
  st.write(f"Grade {class_name}")
 
197
  with col_value:
198
  st.write(f"{confidence:.2f}%")
199
 
200
+ st.markdown("---")
201
  else:
202
+ st.info("Upload an image to see detection and classification results")
203
+
 
 
204
  if __name__ == "__main__":
205
  main()