# Header captured from the file's hosting page (not part of the program):
# author SonFox2920 - commit be85061 "Update app.py" (verified) - 11.5 kB
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import io
import torch
import cloudinary
import cloudinary.uploader
from cloudinary.utils import cloudinary_url
import os
import random
import string
# Cloudinary account settings. The credentials are taken from the CLOUD, API
# and SECRET environment variables; secure=True asks for https URLs.
_CLOUDINARY_SETTINGS = {
    "cloud_name": os.getenv("CLOUD"),
    "api_key": os.getenv("API"),
    "api_secret": os.getenv("SECRET"),
    "secure": True,
}
cloudinary.config(**_CLOUDINARY_SETTINGS)

# Streamlit page settings: browser-tab title, tab icon and wide layout.
_PAGE_SETTINGS = {
    "page_title": "Stone Detection & Classification",
    "page_icon": "🪨",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
def generate_random_filename(extension="png"):
    """Return a throwaway file name ``temp_image_<token>.<extension>``.

    ``<token>`` is 8 characters drawn (with replacement, via
    ``random.choices``) from ASCII letters and digits.
    """
    alphabet = string.ascii_letters + string.digits
    token = "".join(random.choices(alphabet, k=8))
    return "temp_image_{}.{}".format(token, extension)
# Extra CSS for the page: padding around the main area, full-width buttons
# with a top margin, and centred/padded ".upload-text" elements.
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
margin-top: 1rem;
}
.upload-text {
text-align: center;
padding: 2rem;
}
</style>
"""
# unsafe_allow_html is required for the raw <style> tag to be rendered.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def upload_to_cloudinary(file_path, label):
    """
    Upload a local file to Cloudinary into a folder named after ``label``.

    Args:
        file_path: Path of the local file to upload.
        label: Folder name on Cloudinary; also used as the prefix of the
            asset's public ID.

    Returns:
        On success, a dict with keys:
            "upload_result": the raw response of ``cloudinary.uploader.upload``,
            "optimize_url": delivery URL with automatic format and quality,
            "auto_crop_url": delivery URL auto-cropped to 500x500.
        On any failure, a human-readable error string instead. Callers tell
        the two cases apart with ``isinstance(result, dict)``; exceptions are
        never propagated.
    """
    try:
        # Cloudinary keeps a "." as an ordinary character of a public ID and
        # advises omitting the file extension for images. Keeping ".png" would
        # produce IDs like "..._ab12cd34.png", whose delivery URLs do not
        # resolve to the uploaded asset. So build the ID from the bare stem.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        upload_result = cloudinary.uploader.upload(
            file_path,
            folder=label,
            public_id=f"{label}_{stem}",
        )
        public_id = upload_result['public_id']

        # Optimized delivery URL: format and quality chosen automatically.
        optimize_url, _ = cloudinary_url(
            public_id,
            fetch_format="auto",
            quality="auto",
        )
        # 500x500 automatically cropped variant.
        auto_crop_url, _ = cloudinary_url(
            public_id,
            width=500,
            height=500,
            crop="auto",
            gravity="auto",
        )
    except Exception as e:
        # Deliberate boundary: report the failure to the UI as text.
        return f"Error uploading to Cloudinary: {str(e)}"

    return {
        "upload_result": upload_result,
        "optimize_url": optimize_url,
        "auto_crop_url": auto_crop_url,
    }
def resize_to_square(image):
    """Pad ``image`` with zero (black) borders so it becomes square, centred.

    No resampling happens: the output side length equals the longer side of
    the input, and the input pixels are copied unchanged into the middle.
    The input's dtype and any trailing channel dimension are preserved. An
    (H, W, 3) uint8 image yields a (S, S, 3) uint8 image; a 2-D grayscale or
    an (H, W, 4) image is padded the same way instead of failing.

    Args:
        image: NumPy array of shape (H, W) or (H, W, C).

    Returns:
        A new zero-initialised array of shape (S, S) + image.shape[2:],
        with S = max(H, W), containing ``image`` centred.
    """
    height, width = image.shape[:2]
    size = max(height, width)
    # Keep the channel layout and dtype of the input rather than forcing
    # (size, size, 3) uint8, which broke grayscale/RGBA input and truncated
    # non-uint8 data.
    padded = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
    # Offsets that centre the original inside the square (extra pixel of
    # padding, if any, goes to the bottom/right).
    top = (size - height) // 2
    left = (size - width) // 2
    padded[top:top + height, left:left + width] = image
    return padded
@st.cache_resource
def load_models():
    """Load the object-detection and classification models.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded models
    across script reruns instead of reading the files again.

    Returns:
        tuple: ``(object_detection_model, classification_model, device)``.
            - The detection model is loaded with ``torch.load``, moved to
              ``device`` and put in eval mode.
            - The classifier is the Keras model from ``custom_model.h5``.
            - ``device`` is CUDA when available, otherwise CPU.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Object detection model. The checkpoint holds a whole pickled module (it
    # is used directly as a model below), not just a state_dict. PyTorch 2.6
    # switched torch.load's default to weights_only=True, which rejects such
    # checkpoints, so opt out explicitly.
    # SECURITY: weights_only=False unpickles arbitrary objects - only load
    # checkpoint files you trust.
    object_detection_model = torch.load(
        "fasterrcnn_resnet50_fpn_090824.pth",
        map_location=device,
        weights_only=False,
    )
    object_detection_model.to(device)
    object_detection_model.eval()

    # Classification model
    classification_model = tf.keras.models.load_model('custom_model.h5')
    return object_detection_model, classification_model, device
def perform_object_detection(image, model, device, score_threshold=0.75):
    """Detect stone surfaces in ``image`` and crop each kept detection.

    The image is resized to 256x256 and fed to the detector. Detections with
    score >= ``score_threshold`` and model label 1 or 2 are kept. Each kept
    box is mapped back to full resolution, cropped from the original image,
    padded to a square with ``resize_to_square``, and drawn (numbered 1..N in
    the same order as the returned crops) on the 256x256 preview.

    Args:
        image: PIL image; ``image.size`` is read as (width, height).
        model: Detector called on a (1, C, 256, 256) float tensor. Its output
            is indexed as ``outputs[0]['boxes' | 'labels' | 'scores']``
            (presumably a torchvision-style detection model).
        device: torch device the input tensor is moved to.
        score_threshold: Minimum detection score to keep (default 0.75).

    Returns:
        tuple: (annotated 256x256 PIL image, list of square-padded crops taken
        from the full-resolution image, list of (x1, y1, x2, y2) boxes in
        256x256 coordinates). The three outputs are aligned: crop k is the
        box labelled "Region k+1".
    """
    target_w, target_h = 256, 256
    full_res = np.array(image)  # computed once, reused for every crop
    original_w, original_h = image.size
    scale_w = original_w / target_w
    scale_h = original_h / target_h

    frame_resized = cv2.resize(full_res, dsize=(target_w, target_h), interpolation=cv2.INTER_AREA)
    # Detector input: channel order swapped (RGB -> BGR), scaled to [0, 1],
    # laid out as (1, C, H, W).
    frame = cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
    tensor = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
    detections = outputs[0]
    boxes = detections['boxes'].detach().cpu().numpy().astype(np.int32)
    labels = detections['labels'].detach().cpu().numpy().astype(np.int32)
    scores = detections['scores'].detach().cpu().numpy()

    result_image = frame_resized.copy()
    cropped_images = []
    detected_boxes = []
    color = (0, 0, 255)

    for box, label, score in zip(boxes, labels, scores):
        # Keep only confident detections of the stone classes (labels 1 and 2).
        if score < score_threshold or int(label) not in (1, 2):
            continue
        x1, y1, x2, y2 = (int(v) for v in box)

        # Map the box from the 256x256 detector frame back to full resolution.
        x1_orig, y1_orig = int(x1 * scale_w), int(y1 * scale_h)
        x2_orig, y2_orig = int(x2 * scale_w), int(y2 * scale_h)
        cropped_image = full_res[y1_orig:y2_orig, x1_orig:x2_orig]
        if cropped_image.size == 0:
            # Degenerate box: nothing to classify (an empty crop would make
            # the later cv2.resize in preprocessing fail).
            continue
        # Drop an alpha channel if present (only for 3-D, 4-channel crops).
        if cropped_image.ndim == 3 and cropped_image.shape[-1] == 4:
            cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_RGBA2RGB)

        cropped_images.append(resize_to_square(cropped_image))
        detected_boxes.append((x1, y1, x2, y2))

        # Number regions 1..N in crop order so the drawn labels match the
        # "Region k" numbering used when listing classification results.
        label_text = f'Region {len(cropped_images)}'
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 3)
        # Keep the caption inside the frame for boxes touching the top edge.
        cv2.putText(result_image, label_text, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return Image.fromarray(result_image), cropped_images, detected_boxes
def preprocess_image(image):
    """Prepare an image for the classifier.

    The input is turned into a NumPy array, resized to 256x256 with
    ``cv2.resize`` (default interpolation), cast to float32 and divided by
    255 (mapping uint8 pixel values into [0, 1]).
    """
    resized = cv2.resize(np.array(image), (256, 256))
    return resized.astype('float32') / 255.0
def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` highest-scoring classes, best first.

    Args:
        prediction: Class scores. Either a 2-D array whose first row is used
            (e.g. the ``(1, n_classes)`` output of a single-sample
            ``predict``), or a 1-D array of scores.
        class_names: Sequence of names indexed by class position.
        top_k: Number of classes to return. Values <= 0 yield an empty list;
            values larger than the number of classes return all of them.

    Returns:
        list of ``(class_name, score * 100)`` tuples in descending score order.
    """
    if top_k <= 0:
        # Guard: the slice [-0:] would otherwise select every class.
        return []
    scores = np.atleast_2d(np.asarray(prediction))[0]
    best_first = np.argsort(scores)[::-1][:top_k]
    return [(class_names[i], float(scores[i]) * 100) for i in best_first]
def _reset_analysis():
    """Clear the stored results of the previously analysed image."""
    st.session_state.image_digest = None
    st.session_state.result_image = None
    st.session_state.crops = None
    st.session_state.predictions = None


def _analyze_image(image, digest, class_names):
    """Run detection and classification on ``image`` and store the results in session state.

    Old results are cleared first. ``image_digest`` is written last, only after
    every step succeeded, so a failed run leaves no stale results behind.
    """
    _reset_analysis()
    object_detection_model, classification_model, device = load_models()
    result_image, cropped_images, _ = perform_object_detection(
        image, object_detection_model, device
    )
    all_predictions = []
    for cropped_image in cropped_images:
        processed_image = preprocess_image(cropped_image)
        prediction = classification_model.predict(
            np.expand_dims(processed_image, axis=0)
        )
        all_predictions.append(get_top_predictions(prediction, class_names))
    st.session_state.result_image = result_image
    st.session_state.crops = cropped_images
    st.session_state.predictions = all_predictions
    st.session_state.image_digest = digest


def _upload_correction(crop, correct_grade, upload_key):
    """Upload ``crop`` to Cloudinary under ``correct_grade`` once, then show the outcome.

    Streamlit reruns the script on every widget interaction. The secure URL of
    a successful upload is remembered under ``upload_key`` in
    ``st.session_state.feedback_uploads``, so later reruns only redisplay it
    instead of uploading the same correction again.
    """
    uploads = st.session_state.feedback_uploads
    if upload_key not in uploads:
        # Write the crop, resized to 256x256, to a temporary file for the upload.
        temp_image_path = generate_random_filename()
        try:
            Image.fromarray(crop).resize((256, 256)).save(temp_image_path)
            cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
        finally:
            # Never leave temporary files behind, even when saving/uploading failed.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
        if not isinstance(cloudinary_result, dict):
            # upload_to_cloudinary returns an error message string on failure.
            st.error(cloudinary_result)
            return
        uploads[upload_key] = cloudinary_result['upload_result']['secure_url']
    st.success(f"Hình ảnh đã được tải lên thành công cho màu {correct_grade}")
    st.write(f"URL công khai: {uploads[upload_key]}")


def _render_region(idx, predictions, crop, digest, class_names):
    """Show the predictions for detected region ``idx`` and its accuracy-feedback form."""
    st.markdown(f"### Region {idx + 1}")
    # Main prediction
    top_class, top_confidence = predictions[0]
    st.markdown(f"**Primary Prediction: Grade {top_class}**")
    st.markdown(f"**Confidence: {top_confidence:.2f}%**")
    st.progress(top_confidence / 100)
    # All returned predictions for this region
    st.markdown("**Top 5 Predictions**")
    for class_name, confidence in predictions:
        col_label, col_bar, col_value = st.columns([2, 6, 2])
        with col_label:
            st.write(f"Grade {class_name}")
        with col_bar:
            st.progress(confidence / 100)
        with col_value:
            st.write(f"{confidence:.2f}%")
    st.markdown("---")

    # User confirmation section
    st.markdown("### Xác nhận độ chính xác của mô hình")
    st.write("Giúp chúng tôi cải thiện mô hình bằng cách xác nhận độ chính xác của dự đoán.")
    # Widget keys include part of the image digest so answers given for a
    # previous image are not carried over (and re-submitted) for a new one.
    key_suffix = f"{digest[:16]}_{idx}"
    accuracy_option = st.radio(
        "Dự đoán có chính xác không?",
        ["Chọn", "Chính xác", "Không chính xác"],
        index=0,
        key=f"accuracy_radio_{key_suffix}",
    )
    if accuracy_option != "Không chính xác":
        return
    # Ask for the correct grade
    correct_grade = st.selectbox(
        "Chọn màu đá đúng:",
        class_names,
        index=None,
        placeholder="Chọn màu đúng",
        key=f"selectbox_correct_grade_{key_suffix}",
    )
    # Only act once the user has picked a value in the selectbox.
    if correct_grade:
        st.info(f"Đã chọn màu đúng: {correct_grade}")
        # Upload the crop of THIS region, not whichever crop was processed last.
        _upload_correction(crop, correct_grade, (digest, idx, correct_grade))


def main():
    """Render the page: image upload, stone detection, classification and feedback.

    Streamlit reruns this function on every widget interaction. Analysis
    results live in ``st.session_state`` keyed by a SHA-256 digest of the
    uploaded bytes, so the models only run when a different image is
    uploaded. Each user correction is uploaded to Cloudinary only once.
    """
    import hashlib  # local import: only used to fingerprint uploads

    st.title("🪨 Stone Detection & Classification")
    st.write("Upload an image to detect and classify stone surfaces")

    class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']

    # First run of this session: create every state entry used below.
    for key in ("predictions", "crops", "result_image", "image_digest"):
        if key not in st.session_state:
            st.session_state[key] = None
    if "feedback_uploads" not in st.session_state:
        st.session_state.feedback_uploads = {}

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            # The upload was cleared: drop results that belonged to the old image.
            _reset_analysis()
        else:
            # Convert to 3-channel RGB: grayscale/palette images would otherwise
            # reach perform_object_detection as 2-D arrays, and RGBA images as
            # 4-channel arrays.
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", use_column_width=True)
            digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
            if st.session_state.image_digest != digest:
                with st.spinner('Processing image...'):
                    try:
                        _analyze_image(image, digest, class_names)
                    except Exception as e:
                        # UI boundary: report the failure instead of crashing the page.
                        st.error(f"Error during processing: {str(e)}")
            if st.session_state.image_digest == digest:
                if not st.session_state.crops:
                    st.warning("No stone surfaces detected in the image")
                else:
                    st.subheader("Detection Results")
                    st.image(st.session_state.result_image,
                             caption="Detected Stone Surfaces", use_column_width=True)

    with col2:
        st.subheader("Classification Results")
        if st.session_state.predictions is None:
            st.info("Upload an image to see detection and classification results")
        else:
            digest = st.session_state.image_digest or ""
            for idx, predictions in enumerate(st.session_state.predictions):
                _render_region(idx, predictions, st.session_state.crops[idx],
                               digest, class_names)


if __name__ == "__main__":
    main()