import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from sklearn.preprocessing import StandardScaler
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.layers import Lambda  # make sure Lambda comes from tensorflow.keras.layers

# Set page config
st.set_page_config(
    page_title="Stone Classification",
    page_icon="🪨",
    layout="wide"
)

# Custom CSS to improve the appearance
# (the original style rules are not reproduced here)
st.markdown("""
""", unsafe_allow_html=True)


@st.cache_resource
def load_model():
    """Load the trained MLP classifier."""
    return tf.keras.models.load_model('mlp_model.h5')


def color_histogram(image, bins=16):
    """Per-channel color histogram, each normalized to sum to 1 (3 * bins values)."""
    hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
    hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
    hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
    hist_r = hist_r / np.sum(hist_r)
    hist_g = hist_g / np.sum(hist_g)
    hist_b = hist_b / np.sum(hist_b)
    return np.concatenate([hist_r, hist_g, hist_b])


def color_moments(image):
    """Mean, standard deviation, and skewness for each color channel (9 values)."""
    img = image.astype(np.float32) / 255.0
    moments = []
    for i in range(3):  # for each color channel
        channel = img[:, :, i]
        mean = np.mean(channel)
        std = np.std(channel)
        skewness = np.mean(((channel - mean) / (std + 1e-8)) ** 3)
        moments.extend([mean, std, skewness])
    return np.array(moments)


def dominant_color_descriptor(image, k=3):
    """k-means dominant colors plus their pixel percentages (4 * k values)."""
    pixels = image.reshape(-1, 3)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    flags = cv2.KMEANS_RANDOM_CENTERS
    try:
        _, labels, centers = cv2.kmeans(pixels.astype(np.float32), k, None, criteria, 10, flags)
        unique, counts = np.unique(labels, return_counts=True)
        percentages = counts / len(labels)
        dominant_colors = centers.flatten()
        return np.concatenate([dominant_colors, percentages])
    except Exception:
        # Fallback must match the length of the success path (3*k centers + k percentages).
        return np.zeros(4 * k)


def color_coherence_vector(image, k=3):
    """Simplified color coherence vector over the k largest connected components (2 * k values)."""
    # Convert to grayscale and to 8-bit before thresholding
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray = np.uint8(gray)
    # Otsu's thresholding
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Connected components analysis
    num_labels, labels = cv2.connectedComponents(binary)
    ccv = []
    for i in range(1, min(k + 1, num_labels)):
        region_mask = (labels == i)
        total_pixels = np.sum(region_mask)
        coherent_pixels = total_pixels  # simplified: treat every region pixel as coherent
        ccv.extend([coherent_pixels, total_pixels])
    while len(ccv) < 2 * k:
        ccv.append(0)
    return np.array(ccv)


# ViT and feature extraction functions (from the previous implementation);
# the Patches and PatchEncoder layers are not reproduced here.
def extract_features(image):
    """Concatenate all handcrafted color features for a single image."""
    color_hist = color_histogram(image)
    color_mom = color_moments(image)
    dom_color = dominant_color_descriptor(image)
    ccv = color_coherence_vector(image)
    return np.concatenate([color_hist, color_mom, dom_color, ccv])
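

# Optional sanity-check helper (an illustrative addition, not part of the original
# pipeline): the MLP expects a fixed-length input, so it helps to know how long the
# handcrafted portion of the feature vector is. With the defaults above
# (bins=16, k=3) this is 48 + 9 + 12 + 6 = 75 values; the deep features from the
# pretrained backbone are appended on top of that.
def handcrafted_feature_length(bins=16, k=3):
    """Expected length of the concatenated handcrafted color features."""
    histogram = 3 * bins   # color_histogram: one histogram per RGB channel
    moments = 9            # color_moments: mean, std, skewness per channel
    dominant = 4 * k       # dominant_color_descriptor: 3*k centers + k percentages
    coherence = 2 * k      # color_coherence_vector: (coherent, total) per region
    return histogram + moments + dominant + coherence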

def create_vit_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
    """Build a pretrained deep feature extractor.

    A pretrained ViT backbone can be dropped in here; this version uses
    EfficientNetB0 pretrained on ImageNet as a stand-in for ViT.
    """
    inputs = layers.Input(shape=input_shape)
    # Lambda layer applies the backbone's preprocessing to the input image
    x = Lambda(preprocess_input, output_shape=input_shape)(inputs)
    # Pretrained backbone (EfficientNetB0 in place of ViT)
    vit_model = EfficientNetB0(include_top=False, weights='imagenet', input_tensor=x)
    # Global pooling turns the feature maps into a single feature vector
    x = layers.GlobalAveragePooling2D()(vit_model.output)
    if num_classes:
        x = layers.Dense(num_classes, activation='softmax')(x)  # optional classification head
    return models.Model(inputs=inputs, outputs=x)


def preprocess_image(image):
    """Preprocess the uploaded image and return a single combined feature vector."""
    # Ensure a 3-channel RGB image (uploaded PNGs may be grayscale or RGBA)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Convert to numpy array
    img_array = np.array(image)

    # # Preprocess image similar to training
    # img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
    # img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
    # img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)

    # # Adjust brightness
    # target_brightness = 150
    # current_brightness = np.mean(img_array)
    # alpha = target_brightness / (current_brightness + 1e-5)
    # img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)

    # # Apply Gaussian blur
    # img_array = cv2.GaussianBlur(img_array, (5, 5), 0)

    # Resize and normalize (mirrors the training pipeline)
    img_array = cv2.resize(img_array, (256, 256))
    img_array = img_array.astype('float32') / 255.0

    # Handcrafted color features (1-D vector)
    image_features = extract_features(img_array)

    # Deep features from the pretrained backbone (add a batch dimension for predict)
    vit_extractor = create_vit_feature_extractor()
    image_vit = vit_extractor.predict(np.expand_dims(img_array, axis=0)).flatten()

    # Combine both feature sets into a single vector
    image_combined = np.concatenate([image_features, image_vit])

    # NOTE: fitting a StandardScaler on a single sample zeroes every feature;
    # in practice the scaler fitted on the training set should be reused here.
    scaler = StandardScaler()
    image_scaled = scaler.fit_transform(image_combined.reshape(1, -1)).flatten()
    return image_scaled


def get_top_predictions(prediction, class_names, top_k=5):
    """Get the top-k predictions with their probabilities (as percentages)."""
    # Indices of the top-k classes, highest probability first
    top_indices = prediction.argsort()[0][-top_k:][::-1]
    # Corresponding class names and probabilities
    top_predictions = [
        (class_names[i], float(prediction[0][i]) * 100)
        for i in top_indices
    ]
    return top_predictions


def main():
    # Title
    st.title("🪨 Stone Classification")
    st.write("Upload an image of a stone to classify its type")

    # Initialize session state for predictions if it does not exist
    if 'predictions' not in st.session_state:
        st.session_state.predictions = None

    # Create two columns
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_file is not None:
            # Display the uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

            with st.spinner('Analyzing image...'):
                try:
                    # Load model
                    model = load_model()

                    # Preprocess the image into a feature vector
                    processed_image = preprocess_image(image)

                    # Make prediction (add a batch dimension)
                    prediction = model.predict(np.expand_dims(processed_image, axis=0))
                    class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']

                    # Get top 5 predictions
                    top_predictions = get_top_predictions(prediction, class_names)

                    # Store in session state
                    st.session_state.predictions = top_predictions
                except Exception as e:
                    st.error(f"Error during prediction: {str(e)}")

    with col2:
        st.subheader("Prediction Results")

        if st.session_state.predictions is not None:
            # Create a card-like container for the results
            results_container = st.container()

            with results_container:
                # Display the main prediction
                # (the original HTML card markup is not reproduced here)
                st.markdown("", unsafe_allow_html=True)
                top_class, top_confidence = st.session_state.predictions[0]
                st.markdown(f"### Primary Prediction: Grade {top_class}")
                st.markdown(f"### Confidence: {top_confidence:.2f}%")
                st.markdown("", unsafe_allow_html=True)

                # Confidence bar for the top prediction
                st.progress(top_confidence / 100)

                # Display the top 5 predictions
                st.markdown("### Top 5 Predictions")
                st.markdown("", unsafe_allow_html=True)

                for class_name, confidence in st.session_state.predictions:
                    col_label, col_bar, col_value = st.columns([2, 6, 2])
                    with col_label:
                        st.write(f"Grade {class_name}")
                    with col_bar:
                        st.progress(confidence / 100)
                    with col_value:
                        st.write(f"{confidence:.2f}%")
                st.markdown("", unsafe_allow_html=True)
        else:
            st.info("Upload an image to see the prediction results")

    # Footer
    st.markdown("---")
    st.markdown("Made with ❤️ using Streamlit")


if __name__ == "__main__":
    main()