SonFox2920 commited on
Commit
29bcdf2
·
verified ·
1 Parent(s): b76e982

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import io
7
+
8
+ # Set page config
9
+ st.set_page_config(
10
+ page_title="Stone Classification",
11
+ page_icon="🪨",
12
+ layout="wide"
13
+ )
14
+
15
+ # Custom CSS to improve the appearance
16
+ st.markdown("""
17
+ <style>
18
+ .main {
19
+ padding: 2rem;
20
+ }
21
+ .stButton>button {
22
+ width: 100%;
23
+ margin-top: 1rem;
24
+ }
25
+ .upload-text {
26
+ text-align: center;
27
+ padding: 2rem;
28
+ }
29
+ </style>
30
+ """, unsafe_allow_html=True)
31
+
32
+ @st.cache_resource
33
+ def load_model():
34
+ """Load the trained model"""
35
+ return tf.keras.models.load_model('custom_model.h5')
36
+
37
+ def preprocess_image(image):
38
+ """Preprocess the uploaded image"""
39
+ # Convert to RGB if needed
40
+ if image.mode != 'RGB':
41
+ image = image.convert('RGB')
42
+
43
+ # Convert to numpy array
44
+ img_array = np.array(image)
45
+
46
+ # Convert to RGB if needed
47
+ if len(img_array.shape) == 2: # Grayscale
48
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
49
+ elif img_array.shape[2] == 4: # RGBA
50
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
51
+
52
+ # Preprocess image similar to training
53
+ img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
54
+ img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
55
+ img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
56
+
57
+ # Adjust brightness
58
+ target_brightness = 150
59
+ current_brightness = np.mean(img_array)
60
+ alpha = target_brightness / (current_brightness + 1e-5)
61
+ img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
62
+
63
+ # Apply Gaussian blur
64
+ img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
65
+
66
+ # Resize
67
+ img_array = cv2.resize(img_array, (256, 256))
68
+
69
+ # Normalize
70
+ img_array = img_array.astype('float32') / 255.0
71
+
72
+ return img_array
73
+
74
+ def main():
75
+ # Title
76
+ st.title("🪨 Stone Classification")
77
+ st.write("Upload an image of a stone to classify its type")
78
+
79
+ # Initialize session state for prediction if not exists
80
+ if 'prediction' not in st.session_state:
81
+ st.session_state.prediction = None
82
+ if 'confidence' not in st.session_state:
83
+ st.session_state.confidence = None
84
+
85
+ # Create two columns
86
+ col1, col2 = st.columns(2)
87
+
88
+ with col1:
89
+ st.subheader("Upload Image")
90
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
91
+
92
+ if uploaded_file is not None:
93
+ # Display uploaded image
94
+ image = Image.open(uploaded_file)
95
+ st.image(image, caption="Uploaded Image", use_column_width=True)
96
+
97
+ # Add predict button
98
+ if st.button("Predict"):
99
+ try:
100
+ # Load model
101
+ model = load_model()
102
+
103
+ # Preprocess image
104
+ processed_image = preprocess_image(image)
105
+
106
+ # Make prediction
107
+ prediction = model.predict(np.expand_dims(processed_image, axis=0))
108
+ class_names = ['Artificial', 'Nature'] # Replace with your actual class names
109
+
110
+ # Get prediction and confidence
111
+ predicted_class = class_names[np.argmax(prediction)]
112
+ confidence = float(np.max(prediction)) * 100
113
+
114
+ # Store in session state
115
+ st.session_state.prediction = predicted_class
116
+ st.session_state.confidence = confidence
117
+
118
+ except Exception as e:
119
+ st.error(f"Error during prediction: {str(e)}")
120
+
121
+ with col2:
122
+ st.subheader("Prediction Results")
123
+ if st.session_state.prediction is not None:
124
+ # Create a card-like container for results
125
+ results_container = st.container()
126
+ with results_container:
127
+ st.markdown("""
128
+ <style>
129
+ .prediction-card {
130
+ padding: 2rem;
131
+ border-radius: 0.5rem;
132
+ background-color: #f0f2f6;
133
+ margin: 1rem 0;
134
+ }
135
+ </style>
136
+ """, unsafe_allow_html=True)
137
+
138
+ st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
139
+ st.markdown(f"### Predicted Class: {st.session_state.prediction}")
140
+ st.markdown(f"### Confidence: {st.session_state.confidence:.2f}%")
141
+ st.markdown("</div>", unsafe_allow_html=True)
142
+
143
+ # Add confidence bar
144
+ st.progress(st.session_state.confidence / 100)
145
+ else:
146
+ st.info("Upload an image and click 'Predict' to see the results")
147
+
148
+ # Footer
149
+ st.markdown("---")
150
+ st.markdown("Made with ❤️ using Streamlit")
151
+
152
+ if __name__ == "__main__":
153
+ main()