SonFox2920's picture
Update app.py
1cede7c verified
raw
history blame
7.92 kB
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import io
import torch
# Page-wide settings. Keep this ahead of every other st.* call in the script.
_PAGE_CONFIG = dict(
    page_title="Stone Detection & Classification",
    page_icon="🪨",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# Small CSS overrides: padding for the main area, full-width buttons and a
# centred upload hint. Injected as raw HTML, hence unsafe_allow_html.
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
margin-top: 1rem;
}
.upload-text {
text-align: center;
padding: 2rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def resize_to_square(image):
    """Pad an image onto a square canvas, centring the original content.

    The image is NOT scaled: the output side length is the larger of the
    input's height and width, and the uncovered border is filled with zeros.

    Args:
        image: NumPy array of shape (H, W) or (H, W, C).

    Returns:
        A new array of shape (S, S) or (S, S, C) with S = max(H, W) and the
        same dtype as ``image``.
    """
    height, width = image.shape[:2]
    size = max(height, width)
    # Keep any trailing channel axis and the input dtype, so grayscale, RGBA
    # or non-uint8 crops are padded as-is instead of failing to broadcast
    # into (or being silently cast to) a fixed 3-channel uint8 canvas.
    new_img = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
    y_offset = (size - height) // 2
    x_offset = (size - width) // 2
    new_img[y_offset:y_offset + height, x_offset:x_offset + width] = image
    return new_img
@st.cache_resource
def load_models():
    """Load the object-detection and classification models.

    ``st.cache_resource`` memoises the return value, so the model files are
    read on the first call and the same objects are reused on later reruns.

    Returns:
        Tuple ``(object_detection_model, classification_model, device)``:
        the PyTorch detector moved to ``device`` and switched to eval mode,
        the Keras classifier, and the ``torch.device`` in use.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The checkpoint holds a whole pickled nn.Module (we call .to()/.eval()
    # on it), not a state_dict. Since PyTorch 2.6 torch.load defaults to
    # weights_only=True, which refuses such files, so opt out explicitly.
    # This unpickles arbitrary code: only ever load a trusted checkpoint.
    object_detection_model = torch.load(
        "fasterrcnn_resnet50_fpn_090824.pth",
        map_location=device,
        weights_only=False,
    )
    object_detection_model.to(device)
    object_detection_model.eval()

    classification_model = tf.keras.models.load_model("custom_model.h5")
    return object_detection_model, classification_model, device
def perform_object_detection(image, model, device, score_threshold=0.75):
    """Run the detector on ``image`` and crop every confident stone surface.

    The image is resized to 256x256 for inference. Detections are drawn on
    that resized copy, and each box is scaled back to the original resolution
    to crop the region from the full-size image.

    Args:
        image: PIL image as uploaded (any mode; it is converted to RGB).
        model: detection model called on a (1, 3, 256, 256) float tensor and
            expected to return a list whose first element is a dict with
            'boxes', 'labels' and 'scores' (presumably torchvision-style).
        device: torch.device the model lives on.
        score_threshold: minimum score for a detection to be kept.

    Returns:
        Tuple ``(annotated, crops, boxes)``: the annotated 256x256 PIL image,
        a list of square-padded crops (NumPy arrays) cut from the original
        image, and the kept boxes as ``(x1, y1, x2, y2)`` in the 256x256 frame.
    """
    stone_labels = (1, 2)  # labels rendered as "Flame stone surface"
    color = (0, 0, 255)  # NOTE(review): blue on this RGB canvas; red was likely intended (BGR habit)
    label_text = 'Flame stone surface'

    # PNG uploads may be RGBA / palette / grayscale; normalise to 3-channel
    # RGB so both the model input and the crops have the expected layout.
    image = image.convert("RGB")
    image_array = np.array(image)
    original_w, original_h = image.size
    target_size = (256, 256)  # cv2 dsize order: (width, height)
    target_w, target_h = target_size

    frame_resized = cv2.resize(image_array, dsize=target_size, interpolation=cv2.INTER_AREA)
    # NOTE(review): the tensor is fed in BGR order although PIL yields RGB;
    # presumably this matches how the detector was trained — confirm.
    frame = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
    tensor = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)

    boxes = outputs[0]['boxes'].detach().cpu().numpy().astype(np.int32)
    labels = outputs[0]['labels'].detach().cpu().numpy().astype(np.int32)
    scores = outputs[0]['scores'].detach().cpu().numpy()

    # Factors mapping resized-frame coordinates back onto the original image.
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    result_image = frame_resized.copy()
    cropped_images = []
    detected_boxes = []
    for box, label, score in zip(boxes, labels, scores):
        if score < score_threshold:
            continue
        # Only stone-surface labels have a colour/caption; skip anything else
        # instead of drawing with an undefined or stale colour.
        if int(label) not in stone_labels:
            continue
        x1, y1, x2, y2 = (int(v) for v in box)

        # Clamp at 0 so a slightly negative coordinate cannot wrap around
        # as a negative slice index.
        x1_orig, y1_orig = max(int(x1 * scale_w), 0), max(int(y1 * scale_h), 0)
        x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
        cropped_image = image_array[y1_orig:y2_orig, x1_orig:x2_orig]
        if cropped_image.size == 0:
            # Degenerate box after rounding: nothing to classify.
            continue
        cropped_images.append(resize_to_square(cropped_image))
        detected_boxes.append((x1, y1, x2, y2))

        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
        # Keep the caption inside the frame for boxes touching the top edge.
        cv2.putText(result_image, label_text, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return Image.fromarray(result_image), cropped_images, detected_boxes
def preprocess_image(image):
    """Build the classifier's input array from an image.

    Args:
        image: anything ``np.asarray`` accepts as an image (PIL image or
            NumPy array).

    Returns:
        float32 array resized to 256x256 with cv2's default interpolation,
        with every pixel value divided by 255.
    """
    pixels = cv2.resize(np.asarray(image), (256, 256))
    return pixels.astype(np.float32) / 255.0
def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` highest-scoring classes, best first.

    Args:
        prediction: class scores for one sample, shaped ``(num_classes,)``
            or ``(1, num_classes)`` (e.g. Keras ``predict`` on a one-item
            batch). If more rows are given, only the first is used.
        class_names: labels indexed by class position.
        top_k: number of entries to return; values <= 0 give an empty list.

    Returns:
        List of ``(class_name, score * 100)`` tuples in descending score order.
    """
    if top_k <= 0:
        # The old ``[-top_k:]`` slice returned ALL classes for top_k == 0.
        return []
    scores = np.atleast_2d(prediction)[0]
    # argsort is ascending: reverse it and keep the first top_k indices.
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(class_names[i], float(scores[i]) * 100) for i in top_indices]
def main():
    """Render the app: upload + detection on the left, classification on the right.

    Per-region results live in ``st.session_state.predictions`` (one list of
    ``(grade, confidence)`` tuples per detected region). They are cleared
    whenever there is no image, nothing is detected, or processing fails, so
    the right-hand column never shows results belonging to a previous upload.
    """
    st.title("🪨 Stone Detection & Classification")
    st.write("Upload an image to detect and classify stone surfaces")
    if 'predictions' not in st.session_state:
        st.session_state.predictions = None

    col1, col2 = st.columns(2)
    with col1:
        _render_upload_column()
    # Always rendered, even when detection found nothing (previously an
    # early ``return`` skipped this column entirely).
    with col2:
        _render_results_column(st.session_state.predictions)


def _render_upload_column():
    """Show the uploader, run detection + classification, store the results."""
    st.subheader("Upload Image")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # No image any more: drop results left over from an earlier upload.
        st.session_state.predictions = None
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Reset before processing so a failure or an empty detection never leaves
    # the previous image's results on screen.
    st.session_state.predictions = None
    with st.spinner('Processing image...'):
        try:
            object_detection_model, classification_model, device = load_models()
            result_image, cropped_images, _ = perform_object_detection(
                image, object_detection_model, device
            )
            if not cropped_images:
                st.warning("No stone surfaces detected in the image")
                return

            st.subheader("Detection Results")
            st.image(result_image, caption="Detected Stone Surfaces", use_column_width=True)
            st.session_state.predictions = _classify_regions(
                cropped_images, classification_model
            )
        except Exception as exc:  # UI boundary: surface any failure to the user
            st.error(f"Error during processing: {exc}")


def _classify_regions(cropped_images, classification_model):
    """Classify all cropped regions with a single batched ``predict`` call."""
    # NOTE(review): order presumably matches the classifier's training label
    # indices (it looks like sorted directory names) — confirm.
    class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
    batch = np.stack([preprocess_image(crop) for crop in cropped_images])
    predictions = classification_model.predict(batch, verbose=0)
    # Slicing i:i+1 keeps each row 2-D, the shape get_top_predictions reads.
    return [
        get_top_predictions(predictions[i:i + 1], class_names)
        for i in range(len(cropped_images))
    ]


def _render_results_column(predictions):
    """Show per-region grade predictions, or a hint when there are none."""
    st.subheader("Classification Results")
    if not predictions:
        st.info("Upload an image to see detection and classification results")
        return

    for idx, region_predictions in enumerate(predictions, start=1):
        st.markdown(f"### Region {idx}")
        top_class, top_confidence = region_predictions[0]
        st.markdown(f"**Primary Prediction: Grade {top_class}**")
        st.markdown(f"**Confidence: {top_confidence:.2f}%**")
        # st.progress only accepts [0, 1]; clamp against float rounding.
        st.progress(min(max(top_confidence / 100, 0.0), 1.0))

        st.markdown("**Top 5 Predictions**")
        for class_name, confidence in region_predictions:
            col_label, col_bar, col_value = st.columns([2, 6, 2])
            with col_label:
                st.write(f"Grade {class_name}")
            with col_bar:
                st.progress(min(max(confidence / 100, 0.0), 1.0))
            with col_value:
                st.write(f"{confidence:.2f}%")
        st.markdown("---")
if __name__ == "__main__":
main()