# Page residue from the hosting site (kept as comments so the file parses):
# SonFox2920's picture
# Update app.py — commit 029bf0e (verified)
# raw | history | blame | 6.25 kB
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import io
# Page-wide settings. Kept as the first Streamlit command in the script,
# which older Streamlit versions require for set_page_config.
st.set_page_config(page_title="Stone Classification", page_icon="🪨", layout="wide")

# Extra CSS rules injected once per script run to tweak the default look:
# page padding, full-width buttons, and classes for the result widgets.
_CUSTOM_CSS = """
<style>
.main {
padding: 2rem;
}
.stButton>button {
width: 100%;
margin-top: 1rem;
}
.upload-text {
text-align: center;
padding: 2rem;
}
.prediction-card {
padding: 2rem;
border-radius: 0.5rem;
background-color: #f0f2f6;
margin: 1rem 0;
}
.top-predictions {
margin-top: 2rem;
padding: 1rem;
background-color: white;
border-radius: 0.5rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.12);
}
.prediction-bar {
display: flex;
align-items: center;
margin: 0.5rem 0;
}
.prediction-label {
width: 100px;
font-weight: 500;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Return the Keras classifier stored in ``custom_model.h5``.

    ``st.cache_resource`` makes Streamlit load the model once and reuse the
    same object on later reruns instead of reading it from disk every time.
    The path is relative, so it resolves against the current working directory.
    """
    model_path = 'custom_model.h5'
    return tf.keras.models.load_model(model_path)
def preprocess_image(image, target_size=(256, 256)):
    """Convert an uploaded image into a normalized 3-channel float array.

    Args:
        image: A PIL image (e.g. from ``Image.open``) or anything ``np.array``
            turns into an H x W or H x W x C image array.
        target_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            the 256 x 256 size this function always used.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)``. For 8-bit input the
        values lie in ``[0, 1]``. Add a batch axis before ``model.predict``.
    """
    # PNG uploads are frequently RGBA, grayscale ('L') or palette ('P').
    # Without normalizing, a 4-channel or 2-D array reaches the model, which
    # fails on shape (or, for 'P', sees palette indices instead of colors).
    # NOTE(review): assumes the model was trained on 3-channel RGB input —
    # confirm against the training pipeline.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.array(image)

    # Same normalization for callers that pass raw arrays instead of PIL images.
    if img_array.ndim == 2 or (img_array.ndim == 3 and img_array.shape[2] == 1):
        img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    elif img_array.ndim == 3 and img_array.shape[2] == 4:
        img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)

    # NOTE(review): an HSV histogram-equalization / brightness-scaling /
    # Gaussian-blur pipeline used to sit here, fully commented out. Inference
    # now does only resize + /255 — confirm this matches how the model was trained.
    img_array = cv2.resize(img_array, target_size)

    # Scale 8-bit pixel values to [0, 1].
    return img_array.astype('float32') / 255.0
def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` highest-scoring classes, best first.

    Args:
        prediction: Model output for a single image. Either a 1-D score vector
            or a batch array whose first row is used (as ``model.predict``
            returns for a batch of one).
        class_names: Labels aligned index-for-index with the score vector.
        top_k: Number of entries to return. Values <= 0 give an empty list.
            Values larger than the number of classes return every class.

    Returns:
        A list of ``(class_name, score * 100)`` tuples in descending score
        order (a percentage when the model outputs probabilities). Ties keep
        the lower class index first.

    Raises:
        ValueError: If ``class_names`` and the score vector differ in length.
    """
    # atleast_2d lets 1-D vectors and (batch, classes) arrays share one path.
    scores = np.atleast_2d(np.asarray(prediction, dtype=float))[0]
    if len(class_names) != scores.shape[0]:
        raise ValueError(
            f"Got {len(class_names)} class names for {scores.shape[0]} model outputs"
        )
    # Guard explicitly: slicing with [-0:] would have returned every class.
    if top_k <= 0:
        return []
    # Negate for descending order; a stable sort makes tie-breaking deterministic.
    top_indices = np.argsort(-scores, kind="stable")[:top_k]
    return [(class_names[i], float(scores[i]) * 100) for i in top_indices]
# Grade labels in the model's output order. Their lexicographic ordering
# ('10' before '6.5') suggests they mirror sorted training-directory names.
# NOTE(review): confirm against the training pipeline.
_CLASS_NAMES = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']


def _progress_value(percent):
    """Map a 0-100 score onto the 0.0-1.0 float range st.progress accepts."""
    # st.progress raises for floats outside [0.0, 1.0]. Clamp so a score that
    # is not a clean probability (e.g. raw logits) cannot crash the page.
    return min(max(percent / 100, 0.0), 1.0)


def _classify_upload(uploaded_file):
    """Show the uploaded image and return its top predictions, or None on failure.

    Failures are reported in the UI via st.error instead of being raised.
    """
    try:
        image = Image.open(uploaded_file)
    except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded image: {e}")
        return None

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favor of use_container_width; switch once the deployed version supports it.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner('Analyzing image...'):
        try:
            model = load_model()
            processed_image = preprocess_image(image)
            # Add the batch axis; verbose=0 silences Keras' per-call progress output.
            prediction = model.predict(np.expand_dims(processed_image, axis=0), verbose=0)
            return get_top_predictions(prediction, _CLASS_NAMES)
        except Exception as e:  # UI boundary: surface any failure instead of crashing
            st.error(f"Error during prediction: {str(e)}")
            return None


def _render_results(predictions):
    """Render the primary-prediction card and the per-class breakdown."""
    top_class, top_confidence = predictions[0]
    # The styled card must be emitted in ONE st.markdown call. Each call
    # renders as its own element, so an opening <div> in one call cannot wrap
    # content produced by later calls. Labels come from _CLASS_NAMES, not user input.
    st.markdown(
        "<div class='prediction-card'>"
        f"<h3>Primary Prediction: Grade {top_class}</h3>"
        f"<h3>Confidence: {top_confidence:.2f}%</h3>"
        "</div>",
        unsafe_allow_html=True,
    )
    st.progress(_progress_value(top_confidence))

    st.markdown(f"### Top {len(predictions)} Predictions")
    for class_name, confidence in predictions:
        col_label, col_bar, col_value = st.columns([2, 6, 2])
        with col_label:
            st.write(f"Grade {class_name}")
        with col_bar:
            st.progress(_progress_value(confidence))
        with col_value:
            st.write(f"{confidence:.2f}%")


def main():
    """Build the two-column upload / results page.

    While a file is uploaded, predictions are recomputed on every script
    rerun. They are cleared when the upload is removed or classification
    fails, so the results column never shows results for a different image.
    """
    st.title("🪨 Stone Classification")
    st.write("Upload an image of a stone to classify its type")

    if 'predictions' not in st.session_state:
        st.session_state.predictions = None

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is None:
            # The user removed (or never added) a file: drop any earlier results.
            st.session_state.predictions = None
        else:
            # None on failure, so stale results from a previous image disappear.
            st.session_state.predictions = _classify_upload(uploaded_file)

    with col2:
        st.subheader("Prediction Results")
        if st.session_state.predictions is not None:
            _render_results(st.session_state.predictions)
        else:
            # Classification runs automatically on upload; there is no Predict button.
            st.info("Upload an image to see the results")

    # Footer
    st.markdown("---")
    st.markdown("Made with ❤️ using Streamlit")
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()