import io

import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image

# Set page config
st.set_page_config(
    page_title="Stone Classification",
    page_icon="🪨",
    layout="wide"
)

# Custom CSS to improve the appearance
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_model(model_path="custom_model.h5"):
    """Load the trained Keras model.

    Args:
        model_path: Path of the saved model file. Defaults to
            ``"custom_model.h5"``, the file this app was written against.

    Returns:
        The model returned by ``tf.keras.models.load_model``.

    ``st.cache_resource`` makes Streamlit keep the loaded model and reuse it
    on later reruns with the same ``model_path``, instead of reading the file
    again on every interaction.
    """
    return tf.keras.models.load_model(model_path)


def preprocess_image(image, target_size=(256, 256)):
    """Turn an uploaded image into a normalized array for the model.

    Args:
        image: A PIL image (as returned by ``Image.open``) or an array-like
            image. PIL images in any mode other than ``"RGB"`` (e.g. RGBA or
            palette PNGs, grayscale ``"L"``) are converted to RGB. For arrays,
            2-D input is treated as grayscale and replicated to 3 channels,
            and a 4th (alpha) channel is dropped.
        target_size: Size passed to ``cv2.resize``, which interprets it as
            ``(width, height)``. Defaults to ``(256, 256)``.

    Returns:
        A ``float32`` array of shape ``(height, width, 3)``. Values are
        divided by 255, so 8-bit input ends up in ``[0, 1]``.
    """
    # Uploads may be RGBA/palette PNGs or grayscale photos. ``np.array`` would
    # then give 4 or 1 channel(s), while the model presumably expects 3-channel
    # RGB (the previously disabled pipeline converted to RGB before anything
    # else) -- NOTE(review): confirm against the model's input shape.
    if isinstance(image, Image.Image):
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_array = np.array(image)
    else:
        img_array = np.asarray(image)
        if img_array.ndim == 2:  # grayscale
            img_array = np.stack([img_array] * 3, axis=-1)
        elif img_array.ndim == 3 and img_array.shape[2] == 4:  # RGBA
            # Copy so cv2 receives a C-contiguous buffer, not a strided view.
            img_array = np.ascontiguousarray(img_array[:, :, :3])

    # NOTE(review): an earlier version also applied HSV histogram
    # equalization on the V channel, brightness scaling to a mean of 150 and a
    # 5x5 Gaussian blur at this point (labelled "similar to training"); those
    # steps are currently disabled. Confirm this matches how the model was
    # trained.

    # Resize
    img_array = cv2.resize(img_array, target_size)

    # Normalize
    img_array = img_array.astype('float32') / 255.0
    return img_array


def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` highest-scoring classes, best first.

    Args:
        prediction: Model output scores. Either a 1-D vector of per-class
            scores or a 2-D batch, in which case only the first row is used
            (e.g. the output of ``model.predict`` on a one-image batch).
        class_names: Labels indexed by class position in ``prediction``.
        top_k: How many classes to return. Values larger than the number of
            classes return all classes; values ``<= 0`` return an empty list.

    Returns:
        A list of ``(class_name, score * 100)`` tuples in descending score
        order. The value is a percentage when the scores are probabilities.
    """
    if top_k <= 0:
        # Guard explicitly: a ``[-0:]`` slice would select the whole array.
        return []
    scores = np.atleast_2d(np.asarray(prediction))[0]
    # Ascending argsort: take the last top_k indices, then reverse them.
    top_indices = np.argsort(scores)[-top_k:][::-1]
    return [(class_names[i], float(scores[i]) * 100) for i in top_indices]


def main():
    # Title
    st.title("🪨 Stone Classification")
    st.write("Upload an image of a stone to classify its type")
    # Initialize session state for prediction
if not exists if 'predictions' not in st.session_state: st.session_state.predictions = None # Create two columns col1, col2 = st.columns(2) with col1: st.subheader("Upload Image") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: # Display uploaded image image = Image.open(uploaded_file) st.image(image, caption="Uploaded Image", use_column_width=True) with st.spinner('Analyzing image...'): try: # Load model model = load_model() # Preprocess image processed_image = preprocess_image(image) # Make prediction prediction = model.predict(np.expand_dims(processed_image, axis=0)) class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'] # Get top 5 predictions top_predictions = get_top_predictions(prediction, class_names) # Store in session state st.session_state.predictions = top_predictions except Exception as e: st.error(f"Error during prediction: {str(e)}") with col2: st.subheader("Prediction Results") if st.session_state.predictions is not None: # Create a card-like container for results results_container = st.container() with results_container: # Display main prediction st.markdown("