Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
import io
|
7 |
+
|
8 |
+
# Set page config
|
9 |
+
st.set_page_config(
|
10 |
+
page_title="Stone Classification",
|
11 |
+
page_icon="🪨",
|
12 |
+
layout="wide"
|
13 |
+
)
|
14 |
+
|
15 |
+
# Custom CSS to improve the appearance
|
16 |
+
st.markdown("""
|
17 |
+
<style>
|
18 |
+
.main {
|
19 |
+
padding: 2rem;
|
20 |
+
}
|
21 |
+
.stButton>button {
|
22 |
+
width: 100%;
|
23 |
+
margin-top: 1rem;
|
24 |
+
}
|
25 |
+
.upload-text {
|
26 |
+
text-align: center;
|
27 |
+
padding: 2rem;
|
28 |
+
}
|
29 |
+
</style>
|
30 |
+
""", unsafe_allow_html=True)
|
31 |
+
|
32 |
+
@st.cache_resource
|
33 |
+
def load_model():
|
34 |
+
"""Load the trained model"""
|
35 |
+
return tf.keras.models.load_model('custom_model.h5')
|
36 |
+
|
37 |
+
def preprocess_image(image):
|
38 |
+
"""Preprocess the uploaded image"""
|
39 |
+
# Convert to RGB if needed
|
40 |
+
if image.mode != 'RGB':
|
41 |
+
image = image.convert('RGB')
|
42 |
+
|
43 |
+
# Convert to numpy array
|
44 |
+
img_array = np.array(image)
|
45 |
+
|
46 |
+
# Convert to RGB if needed
|
47 |
+
if len(img_array.shape) == 2: # Grayscale
|
48 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
|
49 |
+
elif img_array.shape[2] == 4: # RGBA
|
50 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
|
51 |
+
|
52 |
+
# Preprocess image similar to training
|
53 |
+
img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
|
54 |
+
img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
|
55 |
+
img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
|
56 |
+
|
57 |
+
# Adjust brightness
|
58 |
+
target_brightness = 150
|
59 |
+
current_brightness = np.mean(img_array)
|
60 |
+
alpha = target_brightness / (current_brightness + 1e-5)
|
61 |
+
img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
|
62 |
+
|
63 |
+
# Apply Gaussian blur
|
64 |
+
img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
|
65 |
+
|
66 |
+
# Resize
|
67 |
+
img_array = cv2.resize(img_array, (256, 256))
|
68 |
+
|
69 |
+
# Normalize
|
70 |
+
img_array = img_array.astype('float32') / 255.0
|
71 |
+
|
72 |
+
return img_array
|
73 |
+
|
74 |
+
def main():
|
75 |
+
# Title
|
76 |
+
st.title("🪨 Stone Classification")
|
77 |
+
st.write("Upload an image of a stone to classify its type")
|
78 |
+
|
79 |
+
# Initialize session state for prediction if not exists
|
80 |
+
if 'prediction' not in st.session_state:
|
81 |
+
st.session_state.prediction = None
|
82 |
+
if 'confidence' not in st.session_state:
|
83 |
+
st.session_state.confidence = None
|
84 |
+
|
85 |
+
# Create two columns
|
86 |
+
col1, col2 = st.columns(2)
|
87 |
+
|
88 |
+
with col1:
|
89 |
+
st.subheader("Upload Image")
|
90 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
91 |
+
|
92 |
+
if uploaded_file is not None:
|
93 |
+
# Display uploaded image
|
94 |
+
image = Image.open(uploaded_file)
|
95 |
+
st.image(image, caption="Uploaded Image", use_column_width=True)
|
96 |
+
|
97 |
+
# Add predict button
|
98 |
+
if st.button("Predict"):
|
99 |
+
try:
|
100 |
+
# Load model
|
101 |
+
model = load_model()
|
102 |
+
|
103 |
+
# Preprocess image
|
104 |
+
processed_image = preprocess_image(image)
|
105 |
+
|
106 |
+
# Make prediction
|
107 |
+
prediction = model.predict(np.expand_dims(processed_image, axis=0))
|
108 |
+
class_names = ['Artificial', 'Nature'] # Replace with your actual class names
|
109 |
+
|
110 |
+
# Get prediction and confidence
|
111 |
+
predicted_class = class_names[np.argmax(prediction)]
|
112 |
+
confidence = float(np.max(prediction)) * 100
|
113 |
+
|
114 |
+
# Store in session state
|
115 |
+
st.session_state.prediction = predicted_class
|
116 |
+
st.session_state.confidence = confidence
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
st.error(f"Error during prediction: {str(e)}")
|
120 |
+
|
121 |
+
with col2:
|
122 |
+
st.subheader("Prediction Results")
|
123 |
+
if st.session_state.prediction is not None:
|
124 |
+
# Create a card-like container for results
|
125 |
+
results_container = st.container()
|
126 |
+
with results_container:
|
127 |
+
st.markdown("""
|
128 |
+
<style>
|
129 |
+
.prediction-card {
|
130 |
+
padding: 2rem;
|
131 |
+
border-radius: 0.5rem;
|
132 |
+
background-color: #f0f2f6;
|
133 |
+
margin: 1rem 0;
|
134 |
+
}
|
135 |
+
</style>
|
136 |
+
""", unsafe_allow_html=True)
|
137 |
+
|
138 |
+
st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
|
139 |
+
st.markdown(f"### Predicted Class: {st.session_state.prediction}")
|
140 |
+
st.markdown(f"### Confidence: {st.session_state.confidence:.2f}%")
|
141 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
142 |
+
|
143 |
+
# Add confidence bar
|
144 |
+
st.progress(st.session_state.confidence / 100)
|
145 |
+
else:
|
146 |
+
st.info("Upload an image and click 'Predict' to see the results")
|
147 |
+
|
148 |
+
# Footer
|
149 |
+
st.markdown("---")
|
150 |
+
st.markdown("Made with ❤️ using Streamlit")
|
151 |
+
|
152 |
+
if __name__ == "__main__":
|
153 |
+
main()
|