File size: 8,334 Bytes
29bcdf2
 
 
 
 
bf836f1
 
 
33f7f53
435eb8d
 
29bcdf2
 
 
 
 
 
 
435eb8d
29bcdf2
 
 
 
 
 
 
 
 
aed03ad
 
 
d71712b
aed03ad
 
 
 
 
 
 
 
 
29bcdf2
 
 
33f7f53
29bcdf2
33f7f53
 
435eb8d
33f7f53
 
 
435eb8d
33f7f53
 
435eb8d
bf836f1
435eb8d
bf836f1
 
 
 
435eb8d
 
 
bf836f1
 
 
 
435eb8d
bf836f1
 
435eb8d
 
bf836f1
 
33f7f53
435eb8d
bf836f1
 
 
29bcdf2
bf836f1
435eb8d
 
bf836f1
 
972c8da
bf836f1
 
435eb8d
bf836f1
 
435eb8d
 
33f7f53
bf836f1
 
435eb8d
bf836f1
 
 
 
 
 
 
 
 
 
435eb8d
bf836f1
435eb8d
 
bf836f1
435eb8d
33f7f53
 
435eb8d
bf836f1
435eb8d
bf836f1
435eb8d
 
 
 
 
bf836f1
435eb8d
bf836f1
435eb8d
 
 
33f7f53
 
 
 
 
bf836f1
33f7f53
435eb8d
 
 
 
 
33f7f53
 
 
 
 
 
 
 
 
 
435eb8d
33f7f53
29bcdf2
e30bd59
 
 
 
435eb8d
29bcdf2
 
 
e30bd59
33f7f53
 
e30bd59
33f7f53
 
e30bd59
33f7f53
29bcdf2
aed03ad
 
 
435eb8d
aed03ad
 
 
 
29bcdf2
 
 
 
33f7f53
 
 
 
 
 
435eb8d
aed03ad
 
29bcdf2
 
 
 
 
 
 
 
435eb8d
 
 
 
 
33f7f53
435eb8d
1f2f1b7
 
435eb8d
1f2f1b7
435eb8d
 
29bcdf2
 
 
435eb8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29bcdf2
435eb8d
29bcdf2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
import joblib
import io

# Browser-tab title/icon and the wide page layout.
st.set_page_config(layout="wide", page_title="Stone Classification", page_icon="🪨")

# Page-wide CSS: padding for the main area, full-width buttons, and the
# styles for the `prediction-card` / `top-predictions` classes used in main().
_CUSTOM_CSS = """
    <style>
    .main {
        padding: 2rem;
    }
    .stButton>button {
        width: 100%;
        margin-top: 1rem;
    }
    .prediction-card {
        padding: 2rem;
        border-radius: 0.5rem;
        background-color: #d7d7d9;
        margin: 1rem 0;
    }
    .top-predictions {
        margin-top: 2rem;
        padding: 1rem;
        background-color: white;
        border-radius: 0.5rem;
        box-shadow: 0 1px 3px rgba(0,0,0,0.12);
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Cache only successful loads: the cached helper raises on failure, and
# Streamlit does not store a result for a call that raised, so a later rerun
# retries instead of replaying a cached (None, None).
@st.cache_resource
def _load_model_and_scaler_cached():
    """Load the Keras model and the fitted scaler from disk (cached).

    Raises:
        Whatever ``tf.keras.models.load_model`` / ``joblib.load`` raise when a
        file is missing or unreadable.
    """
    model = tf.keras.models.load_model('mlp_model.h5')
    # joblib.load unpickles the file: only ever point it at trusted local files.
    scaler = joblib.load('scaler.save')
    return model, scaler


def load_model_and_scaler():
    """Load the trained model and scaler.

    Returns:
        ``(model, scaler)`` on success. On failure, an error message is shown
        in the app and ``(None, None)`` is returned; the failure is not cached.
    """
    try:
        return _load_model_and_scaler_cached()
    except Exception as e:  # UI boundary: report the problem instead of crashing the page
        st.error(f"Error loading model or scaler: {e}")
        return None, None

def color_histogram(image, bins=16):
    """Return the per-channel normalized color histograms, concatenated.

    Each of channels 0, 1 and 2 is binned into ``bins`` buckets over the fixed
    value range [0, 256) with ``cv2.calcHist``. Each histogram is then divided
    by its total count, plus 1e-7 to avoid division by zero.

    Args:
        image: Presumably an H x W x 3 array (three channels are read).
        bins: Number of histogram buckets per channel.

    Returns:
        1-D array of length ``3 * bins`` (channel 0 first).

    NOTE(review): ``preprocess_image`` passes a float image scaled to [0, 1].
    With the fixed [0, 256) range, every pixel lands in the first bucket. This
    is left unchanged because the saved scaler/model were presumably fit on
    features computed this way; confirm against the training code.
    """
    normalized = []
    for channel in range(3):
        counts = cv2.calcHist([image], [channel], None, [bins], [0, 256]).flatten()
        normalized.append(counts / (np.sum(counts) + 1e-7))
    return np.concatenate(normalized)

def color_moments(image):
    """Calculate the first three color moments of each channel.

    The image is first divided by 255. Then, for each of the three channels,
    the function computes:

    - the mean,
    - the standard deviation plus ``1e-7``,
    - the skewness ``mean(((x - mean) / std) ** 3)``.

    Args:
        image: Array indexed as ``image[:, :, channel]`` with three channels.

    Returns:
        1-D array of 9 values: ``[mean, std, skew]`` for channels 0, 1, 2.

    NOTE(review): ``preprocess_image`` already scales pixels to [0, 1], so the
    extra division by 255 shrinks the means/stds further. This is kept because
    the fitted scaler presumably expects these values; confirm against training.
    """
    img = image.astype(np.float32) / 255.0
    moments = []

    for i in range(3):
        channel = img[:, :, i]
        mean = np.mean(channel)
        # The epsilon keeps the division below finite for flat channels, so no
        # zero-std special case is needed. It is also part of the reported std
        # (kept for compatibility with the fitted features).
        std = np.std(channel) + 1e-7
        skewness = np.mean(((channel - mean) / std) ** 3)
        moments.extend([mean, std, skewness])

    return np.array(moments)

def dominant_color_descriptor(image, k=3):
    """Describe the image by its ``k`` dominant colors, found with k-means.

    Args:
        image: Array whose values are reshaped into rows of 3 (one per pixel).
        k: Number of clusters.

    Returns:
        1-D array of length ``4 * k``:

        - the ``k`` cluster centers (3 values each, flattened),
        - followed by the fraction of pixels assigned to each cluster.

        If k-means fails (e.g. fewer pixels than clusters), all zeros are
        returned.
    """
    pixels = image.reshape(-1, 3).astype(np.float32)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    # NOTE(review): random initial centers mean cluster order (and possibly
    # centers) can differ between runs on the same image.
    flags = cv2.KMEANS_RANDOM_CENTERS

    try:
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
        # bincount with minlength keeps one entry per cluster, even for a cluster
        # that ends up with no pixels. This keeps the output length fixed at
        # 4 * k, which np.unique(..., return_counts=True) did not guarantee.
        counts = np.bincount(labels.ravel(), minlength=k)
        percentages = counts / labels.size
        return np.concatenate([centers.flatten(), percentages])
    except Exception:
        # Deliberate best-effort fallback: a fixed-length neutral vector.
        return np.zeros(k * 4)

def color_coherence_vector(image, k=3):
    """Return a fixed-length vector built from connected-component sizes.

    Steps:

    1. Convert the image to grayscale (RGB order assumed) and cast it to ``uint8``.
    2. Binarize it with Otsu's threshold.
    3. Label it with ``cv2.connectedComponents``.

    For each of the first ``k`` non-background components (labels 1..k), the
    pixel count is emitted twice. Missing components are zero-padded, so the
    result always has length ``2 * k``.

    NOTE(review): despite the name, the two entries per region are identical;
    there is no coherent/incoherent split. Also, ``np.uint8`` truncates the
    [0, 1] float image from ``preprocess_image`` to 0 everywhere except pure
    white. Both are kept because the fitted scaler/model presumably expect
    these exact features; confirm before changing.
    """
    gray = np.uint8(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY))
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    label_count, label_map = cv2.connectedComponents(binary)

    last_label = min(k + 1, label_count)
    sizes = [np.sum(label_map == label) for label in range(1, last_label)]
    vector = [size for size in sizes for _ in range(2)]
    vector += [0] * (2 * k - len(vector))
    return np.array(vector[:2 * k])

@st.cache_resource
def create_vit_feature_extractor():
    """Build the deep-feature extractor, cached with ``st.cache_resource``.

    Despite the ``vit`` in its name, this is not a Vision Transformer. The
    backbone is an ImageNet-pretrained ``EfficientNetB0`` without its
    classification head, preceded by ``preprocess_input`` and followed by
    global average pooling.

    Returns:
        A Keras model. It maps a batch of 256x256x3 images to one pooled
        feature vector per image.

    NOTE(review): Keras documents its EfficientNet models as taking [0, 255]
    inputs, but ``preprocess_image`` feeds [0, 1] values. Confirm this matches
    how the training features were produced.
    """
    image_input = layers.Input(shape=(256, 256, 3))
    preprocessed = layers.Lambda(preprocess_input)(image_input)

    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=preprocessed,
    )

    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    return models.Model(inputs=image_input, outputs=pooled)

def extract_features(image):
    """Build the full feature vector for a single preprocessed image.

    The features are concatenated in this order (presumably the order the
    scaler/model were trained on, so do not reorder):

    1. color histogram,
    2. color moments,
    3. dominant colors,
    4. connected-component descriptor,
    5. the pooled embedding from ``create_vit_feature_extractor``, which is an
       EfficientNetB0 despite its name.

    Returns:
        1-D array of concatenated features.
    """
    handcrafted = np.concatenate([
        color_histogram(image),
        color_moments(image),
        dominant_color_descriptor(image),
        color_coherence_vector(image),
    ])

    # Deep features: run the cached CNN on a batch containing just this image.
    batch = image[np.newaxis, ...]
    deep = create_vit_feature_extractor().predict(batch, verbose=0)

    return np.concatenate([handcrafted, deep.flatten()])

def preprocess_image(image, scaler):
    """Turn an uploaded PIL image into a scaled feature row for the model.

    Steps:

    1. Force RGB mode.
    2. Resize to 256x256 with ``cv2.resize``.
    3. Convert to float32 in [0, 1].
    4. Extract features.
    5. Apply ``scaler.transform`` to them as a single row.

    Args:
        image: A PIL image.
        scaler: Fitted scaler with a ``transform`` method (loaded via joblib).

    Returns:
        Whatever ``scaler.transform`` returns for a ``(1, n_features)`` input.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    pixels = cv2.resize(np.array(rgb), (256, 256)).astype('float32') / 255.0
    feature_row = extract_features(pixels).reshape(1, -1)
    return scaler.transform(feature_row)

def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` most probable classes as ``(name, percent)`` pairs.

    Args:
        prediction: Model output for one sample. It may be shaped
            ``(1, n_classes)`` (as from ``model.predict`` on one row) or be a
            flat ``(n_classes,)`` array.
        class_names: Class names indexed by output position.
        top_k: Number of entries to return. Values <= 0 yield an empty list;
            values larger than the number of classes return every class.

    Returns:
        List of ``(class_name, probability * 100)`` tuples, highest first.
    """
    scores = np.asarray(prediction)
    if scores.ndim == 2:
        scores = scores[0]  # only the first sample is ranked
    if top_k <= 0:
        # Without this guard, the slice [-0:] would select every class.
        return []
    top_indices = scores.argsort()[::-1][:top_k]
    return [(class_names[i], float(scores[i]) * 100) for i in top_indices]

# Output classes in the model's index order.
# NOTE(review): the order looks like a lexicographic sort of the grade strings
# ('10' before '6.5'), as a label encoder would produce. Confirm it matches the
# training labels.
CLASS_NAMES = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']


def _as_fraction(percent):
    """Convert a percentage to a value clamped to [0, 1] for ``st.progress``."""
    # st.progress rejects floats outside [0.0, 1.0]; float rounding in the
    # model output must not crash the page.
    return min(max(percent / 100.0, 0.0), 1.0)


def main():
    """Render the stone-classification page.

    The page:

    1. Loads the model and scaler.
    2. Lets the user upload an image.
    3. Runs the feature/prediction pipeline on it.
    4. Shows the top predictions.

    The results live in ``st.session_state.predictions``. They are cleared
    when no image is uploaded or when processing the image fails, so results
    from a previous image are never shown.
    """
    st.title("🪨 Stone Classification")
    st.write("Upload an image of a stone to classify its type")

    # Load model and scaler
    model, scaler = load_model_and_scaler()
    if model is None or scaler is None:
        st.error("Failed to load model or scaler. Please ensure both files exist.")
        return

    # Initialize session state
    if 'predictions' not in st.session_state:
        st.session_state.predictions = None

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if uploaded_file is None:
            # The file was removed (or never uploaded): drop any old results.
            st.session_state.predictions = None
        else:
            try:
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded Image", use_column_width=True)

                with st.spinner('Analyzing image...'):
                    processed_image = preprocess_image(image, scaler)
                    prediction = model.predict(processed_image, verbose=0)
                    st.session_state.predictions = get_top_predictions(prediction, CLASS_NAMES)

            except Exception as e:  # UI boundary: report instead of crashing the page
                # Don't keep showing results that belong to a different image.
                st.session_state.predictions = None
                st.error(f"Error processing image: {e}")

    with col2:
        st.subheader("Prediction Results")
        if st.session_state.predictions:
            # Display main prediction
            top_class, top_confidence = st.session_state.predictions[0]
            st.markdown(
                f"""
                <div class='prediction-card'>
                    <h3>Primary Prediction: Grade {top_class}</h3>
                    <h3>Confidence: {top_confidence:.2f}%</h3>
                </div>
                """,
                unsafe_allow_html=True
            )

            # Display confidence bar
            st.progress(_as_fraction(top_confidence))

            # Display top 5 predictions
            st.markdown("### Top 5 Predictions")
            # NOTE(review): each st.markdown call renders as its own element, so
            # this opening/closing div pair likely does not wrap the rows below.
            st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)

            for class_name, confidence in st.session_state.predictions:
                cols = st.columns([2, 6, 2])
                with cols[0]:
                    st.write(f"Grade {class_name}")
                with cols[1]:
                    st.progress(_as_fraction(confidence))
                with cols[2]:
                    st.write(f"{confidence:.2f}%")

            st.markdown("</div>", unsafe_allow_html=True)
        else:
            st.info("Upload an image to see the predictions")

    st.markdown("---")
    st.markdown("Made with ❤️ using Streamlit")

# Entry point: render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()