Spaces:
Sleeping
Sleeping
File size: 7,916 Bytes
29bcdf2 435eb8d 2cc5732 435eb8d c2ddd50 2cc5732 c2ddd50 29bcdf2 8be899f c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 1cede7c 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 d58b258 2cc5732 c2ddd50 d58b258 c2ddd50 d58b258 c2ddd50 d58b258 c2ddd50 d58b258 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 d58b258 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 d58b258 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 c2ddd50 2cc5732 29bcdf2 8be899f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import io
import torch
# Configure the browser tab (title and icon) and use the wide page layout.
st.set_page_config(
    layout="wide",
    page_title="Stone Detection & Classification",
    page_icon="🪨",
)

# CSS overrides injected as raw HTML: padded main area, full-width buttons
# with a top margin, and a centred/padded `.upload-text` class.
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
margin-top: 1rem;
}
.upload-text {
text-align: center;
padding: 2rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def resize_to_square(image):
    """Pad *image* onto a black square canvas, centred, without rescaling it.

    The canvas side is ``max(height, width)``, so the content keeps its aspect
    ratio. The shorter dimension is zero-filled on both sides; when the padding
    is odd, the extra pixel goes to the bottom/right.

    Args:
        image: array shaped (H, W) or (H, W, C).

    Returns:
        Array shaped (S, S) or (S, S, C) with S = max(H, W), same dtype as *image*.
    """
    height, width = image.shape[:2]
    size = max(height, width)
    # Keep any trailing channel dimension and the dtype, so grayscale or RGBA
    # inputs also fit. A fixed 3-channel canvas cannot hold them.
    new_img = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)

    y_offset = (size - height) // 2
    x_offset = (size - width) // 2
    new_img[y_offset:y_offset + height, x_offset:x_offset + width] = image
    return new_img
@st.cache_resource
def load_models():
    """Load the detection and classification models (cached by ``st.cache_resource``).

    Returns:
        ``(object_detection_model, classification_model, device)``.
        The detection model is a PyTorch module moved to ``device`` and switched
        to eval mode. The classification model is loaded with Keras.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The checkpoint holds a whole pickled model object (it is used directly
    # via .to()/.eval() below), not a state_dict. It must therefore be
    # unpickled with weights_only=False: PyTorch >= 2.6 defaults to
    # weights_only=True and would refuse to load it.
    # Unpickling can execute arbitrary code, so only load this trusted,
    # bundled file.
    object_detection_model = torch.load(
        "fasterrcnn_resnet50_fpn_090824.pth",
        map_location=device,
        weights_only=False,
    )
    object_detection_model.to(device)
    object_detection_model.eval()

    classification_model = tf.keras.models.load_model('custom_model.h5')
    return object_detection_model, classification_model, device
def perform_object_detection(image, model, device, score_threshold=0.75, target_size=(256, 256)):
    """Detect stone surfaces and crop each detection from the full-resolution image.

    The image is resized to ``target_size`` for the detector. Each box that
    scores at least ``score_threshold`` and carries label 1 or 2 is handled in
    two ways:

    - it is drawn onto the resized frame;
    - it is rescaled to the original resolution, cropped, and padded square
      with ``resize_to_square``.

    Args:
        image: PIL image; non-RGB modes are converted to RGB first.
        model: detector returning, for the single input, a dict with
            ``'boxes'``, ``'labels'`` and ``'scores'`` tensors.
        device: torch device the model lives on.
        score_threshold: minimum detection score for a box to be kept.
        target_size: ``(width, height)`` for the detector input (cv2 ``dsize`` order).

    Returns:
        A tuple of:

        - the annotated PIL image at ``target_size``;
        - the list of square crops;
        - the list of ``(x1, y1, x2, y2)`` boxes in ``target_size`` coordinates.
    """
    # The channel swap and 3-channel crop padding below assume RGB input.
    # RGBA PNGs, grayscale or palette images would otherwise fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    original_w, original_h = image.size      # PIL reports (width, height)
    target_w, target_h = target_size         # cv2 dsize order: (width, height)
    full_res = np.array(image)               # built once, reused for every crop

    frame_resized = cv2.resize(full_res, dsize=target_size, interpolation=cv2.INTER_AREA)
    # NOTE(review): the RGB frame is swapped to BGR before inference, presumably
    # to match how the detector was trained (e.g. on cv2.imread frames). Confirm.
    frame_bgr = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
    # HWC -> CHW, plus a batch dimension of one.
    batch = torch.from_numpy(frame_bgr.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
    detections = outputs[0]
    boxes = detections['boxes'].cpu().numpy().astype(np.int32)
    labels = detections['labels'].cpu().numpy().astype(np.int32)
    scores = detections['scores'].cpu().numpy()

    # Factors mapping detector-frame coordinates back to the original image.
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    color = (0, 0, 255)
    label_text = 'Flame stone surface'
    result_image = frame_resized.copy()
    cropped_images = []
    detected_boxes = []
    for box, label, score in zip(boxes, labels, scores):
        # Keep confident detections with label 1 or 2 only. Every other label
        # is skipped (presumably 0 is background). Confirm against training.
        if not (score >= score_threshold and int(label) in (1, 2)):
            continue
        x1, y1, x2, y2 = (int(v) for v in box)

        # Rescale to full resolution and clamp to the image bounds, so a box
        # touching the border cannot yield a negative (wrap-around) slice index.
        x1_orig = min(max(int(x1 * scale_w), 0), original_w)
        x2_orig = min(max(int(x2 * scale_w), 0), original_w)
        y1_orig = min(max(int(y1 * scale_h), 0), original_h)
        y2_orig = min(max(int(y2 * scale_h), 0), original_h)
        cropped_image = full_res[y1_orig:y2_orig, x1_orig:x2_orig]
        if cropped_image.size == 0:
            # Degenerate box: an empty crop would make the later cv2.resize fail.
            continue
        cropped_images.append(resize_to_square(cropped_image))
        detected_boxes.append((x1, y1, x2, y2))

        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
        cv2.putText(result_image, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return Image.fromarray(result_image), cropped_images, detected_boxes
def preprocess_image(image):
    """Turn one crop into a classifier input.

    The crop is resized to 256x256 with OpenCV's default (bilinear)
    interpolation. The result is cast to float32 and divided by 255, which
    maps uint8 pixels to [0.0, 1.0].
    """
    pixels = cv2.resize(np.array(image), (256, 256))
    return pixels.astype('float32') / 255.0
def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` highest-scoring classes, best first.

    Args:
        prediction: scores for one sample, shaped either (1, n_classes) (as
            returned by ``model.predict`` on a batch of one) or (n_classes,).
            When more rows are present, only the first row is used.
        class_names: names indexed like the score vector.
        top_k: number of entries to return; values <= 0 yield an empty list.

    Returns:
        List of ``(class_name, score * 100)`` tuples in descending score order.
        This presents the scores as percentages, assuming they are probabilities.
    """
    if top_k <= 0:
        # A bare ``[-top_k:]`` slice would return *every* class for top_k == 0.
        return []
    scores = np.atleast_2d(prediction)[0]
    best_first = np.argsort(scores)[::-1][:top_k]
    return [(class_names[i], float(scores[i]) * 100) for i in best_first]
def _classify_uploaded_image(uploaded_file):
    """Detect and grade stone surfaces in one upload, rendering into the current container.

    Returns:
        One list of ``(grade, confidence_percent)`` tuples per detected region.
        Returns None instead when the image cannot be read, nothing is detected,
        or processing fails; a warning or error is shown in that case.
    """
    try:
        # Normalise to 3-channel RGB. PNG uploads may carry alpha or a palette
        # and JPEGs may be grayscale, which the detection/crop pipeline cannot process.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Error during processing: {str(e)}")
        return None
    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner('Processing image...'):
        try:
            object_detection_model, classification_model, device = load_models()
            result_image, cropped_images, _ = perform_object_detection(
                image, object_detection_model, device
            )
            if not cropped_images:
                st.warning("No stone surfaces detected in the image")
                return None

            st.subheader("Detection Results")
            st.image(result_image, caption="Detected Stone Surfaces", use_column_width=True)

            # NOTE(review): presumably listed in the classifier's output-unit
            # order (lexicographic, as Keras directory loaders assign). Confirm.
            class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
            all_predictions = []
            for cropped_image in cropped_images:
                batch = np.expand_dims(preprocess_image(cropped_image), axis=0)
                prediction = classification_model.predict(batch)
                all_predictions.append(get_top_predictions(prediction, class_names))
            return all_predictions
        except Exception as e:  # UI boundary: report any failure instead of crashing the page
            st.error(f"Error during processing: {str(e)}")
            return None


def _as_progress(percent):
    """Convert a percentage to the [0.0, 1.0] range that ``st.progress`` accepts."""
    # Clamp, because st.progress rejects out-of-range values. Scores that are
    # not proper probabilities (e.g. raw logits) would otherwise raise here.
    return min(max(percent / 100, 0.0), 1.0)


def _render_predictions(predictions):
    """Render each region's primary grade and its full top-k list with progress bars."""
    for idx, region_predictions in enumerate(predictions):
        st.markdown(f"### Region {idx + 1}")

        top_class, top_confidence = region_predictions[0]
        st.markdown(f"**Primary Prediction: Grade {top_class}**")
        st.markdown(f"**Confidence: {top_confidence:.2f}%**")
        st.progress(_as_progress(top_confidence))

        st.markdown("**Top 5 Predictions**")
        for class_name, confidence in region_predictions:
            col_label, col_bar, col_value = st.columns([2, 6, 2])
            with col_label:
                st.write(f"Grade {class_name}")
            with col_bar:
                st.progress(_as_progress(confidence))
            with col_value:
                st.write(f"{confidence:.2f}%")
        st.markdown("---")


def main():
    """Streamlit entry point: upload an image, detect stone surfaces, show grade predictions."""
    st.title("🪨 Stone Detection & Classification")
    st.write("Upload an image to detect and classify stone surfaces")

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        # Always overwrite the stored result, so predictions from a removed,
        # replaced or failed upload are never shown for the current state.
        # Not returning early here also means col2 is always rendered.
        st.session_state.predictions = (
            _classify_uploaded_image(uploaded_file) if uploaded_file is not None else None
        )

    with col2:
        st.subheader("Classification Results")
        if st.session_state.predictions is not None:
            _render_predictions(st.session_state.predictions)
        else:
            st.info("Upload an image to see detection and classification results")
# Script entry point (Streamlit also executes the file as __main__).
if __name__ == "__main__":
    main()