SonFox2920 commited on
Commit
d8f1dce
·
verified ·
1 Parent(s): 015a87b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -136
app.py CHANGED
@@ -3,12 +3,7 @@ import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
6
- from tensorflow.keras import layers, models
7
- from tensorflow.keras.applications import EfficientNetB0
8
- from tensorflow.keras.applications.efficientnet import preprocess_input
9
- import joblib
10
  import io
11
- import os
12
 
13
  # Set page config
14
  st.set_page_config(
@@ -17,7 +12,7 @@ st.set_page_config(
17
  layout="wide"
18
  )
19
 
20
- # Custom CSS with improved styling
21
  st.markdown("""
22
  <style>
23
  .main {
@@ -27,10 +22,14 @@ st.markdown("""
27
  width: 100%;
28
  margin-top: 1rem;
29
  }
 
 
 
 
30
  .prediction-card {
31
  padding: 2rem;
32
  border-radius: 0.5rem;
33
- background-color: #d7d7d9;
34
  margin: 1rem 0;
35
  }
36
  .top-predictions {
@@ -40,154 +39,150 @@ st.markdown("""
40
  border-radius: 0.5rem;
41
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
42
  }
43
- .survey-card {
44
- padding: 1rem;
45
- background-color: #f0f2f6;
46
- border-radius: 0.5rem;
47
- margin-top: 1rem;
 
 
 
48
  }
49
  </style>
50
  """, unsafe_allow_html=True)
51
 
52
- # from mega import Mega
53
-
54
- # # Đăng nhập vào tài khoản Mega
55
- # def upload_to_mega(file_path, folder_name):
56
- # """
57
- # Upload file to a specific folder on Mega.nz
58
- # """
59
- # try:
60
- # # Đăng nhập vào tài khoản Mega
61
- # mega = Mega()
62
- # m = mega.login(os.getenv('EMAIL'), os.getenv('PASSWORD'))
63
-
64
- # # Tìm thư mục đích
65
- # folder = m.find(folder_name)
66
-
67
- # if not folder:
68
- # # Nếu thư mục không tồn tại, hiển thị thông báo lỗi
69
- # return f"Thư mục '{folder_name}' không tồn tại!"
70
-
71
- # # Tải tệp lên thư mục
72
- # file = m.upload(file_path, folder[0])
73
- # return f"Upload thành công! Link: {m.get_upload_link(file)}"
74
-
75
- # except Exception as e:
76
- # return f"Lỗi khi tải lên Mega: {str(e)}"
77
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def main():
 
80
  st.title("🪨 Stone Classification")
81
  st.write("Upload an image of a stone to classify its type")
82
-
83
- # Load model and scaler
84
- model, scaler = load_model_and_scaler()
85
- if model is None or scaler is None:
86
- st.error("Failed to load model or scaler. Please ensure both files exist.")
87
- return
88
-
89
- # Initialize session state
90
  if 'predictions' not in st.session_state:
91
  st.session_state.predictions = None
92
- if 'uploaded_image' not in st.session_state:
93
- st.session_state.uploaded_image = None
94
-
95
  col1, col2 = st.columns(2)
96
-
97
  with col1:
98
  st.subheader("Upload Image")
99
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
100
-
101
  if uploaded_file is not None:
102
- try:
103
- image = Image.open(uploaded_file)
104
- st.image(image, caption="Uploaded Image", use_column_width=True)
105
- st.session_state.uploaded_image = image
106
-
107
- with st.spinner('Analyzing image...'):
108
- processed_image = preprocess_image(image, scaler)
109
- prediction = model.predict(processed_image, verbose=0)
110
-
 
 
 
 
 
111
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
112
- st.session_state.predictions = get_top_predictions(prediction, class_names)
113
-
114
- except Exception as e:
115
- st.error(f"Error processing image: {str(e)}")
116
-
 
 
 
 
 
117
  with col2:
118
  st.subheader("Prediction Results")
119
- if st.session_state.predictions:
120
- # Display main prediction
121
- top_class, top_confidence = st.session_state.predictions[0]
122
- st.markdown(
123
- f"""
124
- <div class='prediction-card'>
125
- <h3>Primary Prediction: Grade {top_class}</h3>
126
- <h3>Confidence: {top_confidence:.2f}%</h3>
127
- </div>
128
- """,
129
- unsafe_allow_html=True
130
- )
131
-
132
- # Display confidence bar
133
- st.progress(top_confidence / 100)
134
-
135
- # Display top 5 predictions
136
- st.markdown("### Top 5 Predictions")
137
- st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
138
-
139
- for class_name, confidence in st.session_state.predictions:
140
- cols = st.columns([2, 6, 2])
141
- with cols[0]:
142
- st.write(f"Grade {class_name}")
143
- with cols[1]:
144
- st.progress(confidence / 100)
145
- with cols[2]:
146
- st.write(f"{confidence:.2f}%")
147
-
148
- st.markdown("</div>", unsafe_allow_html=True)
149
-
150
- # User Survey
151
- st.markdown("<div class='survey-card'>", unsafe_allow_html=True)
152
- st.markdown("### Model Accuracy Survey")
153
- st.write("Mô hình có dự đoán chính xác màu sắc của đá trong ảnh này không?")
154
-
155
- # Accuracy Confirmation
156
- accuracy = st.radio(
157
- "Đánh giá độ chính xác",
158
- ["Chọn", "Chính xác", "Không chính xác"],
159
- index=0
160
- )
161
-
162
- if accuracy == "Không chính xác":
163
- # Color input for incorrect prediction
164
- correct_color = st.text_input(
165
- "Vui lòng nhập màu sắc chính xác của đá:",
166
- help="Ví dụ: 10, 9.7, 9.5, 9.2, v.v."
167
- )
168
-
169
- # if st.button("Gửi phản hồi và tải ảnh"):
170
- # if correct_color and st.session_state.uploaded_image:
171
- # # Save the image temporarily
172
- # temp_image_path = f"temp_image_{hash(uploaded_file.name)}.png"
173
- # st.session_state.uploaded_image.save(temp_image_path)
174
-
175
- # # Upload to Mega.nz
176
- # upload_result = upload_to_mega(temp_image_path, correct_color)
177
- # if "Upload thành công" in upload_result:
178
- # st.success(upload_result)
179
- # else:
180
- # st.error(upload_result)
181
-
182
- # # Clean up temporary file
183
- # os.remove(temp_image_path)
184
- # else:
185
- # st.warning("Vui lòng nhập màu sắc chính xác")
186
-
187
- st.markdown("</div>", unsafe_allow_html=True)
188
  else:
189
- st.info("Upload an image to see the predictions")
190
-
 
191
  st.markdown("---")
192
  st.markdown("Made with ❤️ using Streamlit")
193
 
 
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
 
 
 
 
6
  import io
 
7
 
8
  # Set page config
9
  st.set_page_config(
 
12
  layout="wide"
13
  )
14
 
15
+ # Custom CSS to improve the appearance
16
  st.markdown("""
17
  <style>
18
  .main {
 
22
  width: 100%;
23
  margin-top: 1rem;
24
  }
25
+ .upload-text {
26
+ text-align: center;
27
+ padding: 2rem;
28
+ }
29
  .prediction-card {
30
  padding: 2rem;
31
  border-radius: 0.5rem;
32
+ background-color: #f0f2f6;
33
  margin: 1rem 0;
34
  }
35
  .top-predictions {
 
39
  border-radius: 0.5rem;
40
  box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
  }
42
+ .prediction-bar {
43
+ display: flex;
44
+ align-items: center;
45
+ margin: 0.5rem 0;
46
+ }
47
+ .prediction-label {
48
+ width: 100px;
49
+ font-weight: 500;
50
  }
51
  </style>
52
  """, unsafe_allow_html=True)
53
 
54
+ @st.cache_resource
55
+ def load_model():
56
+ """Load the trained model"""
57
+ return tf.keras.models.load_model('custom_model.h5')
58
+
59
+ def preprocess_image(image):
60
+ """Preprocess the uploaded image"""
61
+ # # Convert to RGB if needed
62
+ # if image.mode != 'RGB':
63
+ # image = image.convert('RGB')
64
+
65
+ # Convert to numpy array
66
+ img_array = np.array(image)
67
+
68
+ # # Convert to RGB if needed
69
+ # if len(img_array.shape) == 2: # Grayscale
70
+ # img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
71
+ # elif img_array.shape[2] == 4: # RGBA
72
+ # img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
73
+
74
+ # # Preprocess image similar to training
75
+ # img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
76
+ # img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
77
+ # img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
78
+
79
+ # # Adjust brightness
80
+ # target_brightness = 150
81
+ # current_brightness = np.mean(img_array)
82
+ # alpha = target_brightness / (current_brightness + 1e-5)
83
+ # img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
84
+
85
+ # # Apply Gaussian blur
86
+ # img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
87
+
88
+ # Resize
89
+ img_array = cv2.resize(img_array, (256, 256))
90
+
91
+ # Normalize
92
+ img_array = img_array.astype('float32') / 255.0
93
+
94
+ return img_array
95
+
96
+ def get_top_predictions(prediction, class_names, top_k=5):
97
+ """Get top k predictions with their probabilities"""
98
+ # Get indices of top k predictions
99
+ top_indices = prediction.argsort()[0][-top_k:][::-1]
100
+
101
+ # Get corresponding class names and probabilities
102
+ top_predictions = [
103
+ (class_names[i], float(prediction[0][i]) * 100)
104
+ for i in top_indices
105
+ ]
106
+
107
+ return top_predictions
108
 
109
  def main():
110
+ # Title
111
  st.title("🪨 Stone Classification")
112
  st.write("Upload an image of a stone to classify its type")
113
+
114
+ # Initialize session state for prediction if not exists
 
 
 
 
 
 
115
  if 'predictions' not in st.session_state:
116
  st.session_state.predictions = None
117
+
118
+ # Create two columns
 
119
  col1, col2 = st.columns(2)
120
+
121
  with col1:
122
  st.subheader("Upload Image")
123
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
124
+
125
  if uploaded_file is not None:
126
+ # Display uploaded image
127
+ image = Image.open(uploaded_file)
128
+ st.image(image, caption="Uploaded Image", use_column_width=True)
129
+
130
+ with st.spinner('Analyzing image...'):
131
+ try:
132
+ # Load model
133
+ model = load_model()
134
+
135
+ # Preprocess image
136
+ processed_image = preprocess_image(image)
137
+
138
+ # Make prediction
139
+ prediction = model.predict(np.expand_dims(processed_image, axis=0))
140
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
141
+
142
+ # Get top 5 predictions
143
+ top_predictions = get_top_predictions(prediction, class_names)
144
+
145
+ # Store in session state
146
+ st.session_state.predictions = top_predictions
147
+
148
+ except Exception as e:
149
+ st.error(f"Error during prediction: {str(e)}")
150
+
151
  with col2:
152
  st.subheader("Prediction Results")
153
+ if st.session_state.predictions is not None:
154
+ # Create a card-like container for results
155
+ results_container = st.container()
156
+ with results_container:
157
+ # Display main prediction
158
+ st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
159
+ top_class, top_confidence = st.session_state.predictions[0]
160
+ st.markdown(f"### Primary Prediction: Grade {top_class}")
161
+ st.markdown(f"### Confidence: {top_confidence:.2f}%")
162
+ st.markdown("</div>", unsafe_allow_html=True)
163
+
164
+ # Display confidence bar for top prediction
165
+ st.progress(top_confidence / 100)
166
+
167
+ # Display top 5 predictions
168
+ st.markdown("### Top 5 Predictions")
169
+ st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
170
+
171
+ # Create a Streamlit container for the predictions
172
+ for class_name, confidence in st.session_state.predictions:
173
+ col_label, col_bar, col_value = st.columns([2, 6, 2])
174
+ with col_label:
175
+ st.write(f"Grade {class_name}")
176
+ with col_bar:
177
+ st.progress(confidence / 100)
178
+ with col_value:
179
+ st.write(f"{confidence:.2f}%")
180
+
181
+ st.markdown("</div>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  else:
183
+ st.info("Upload an image and click 'Predict' to see the results")
184
+
185
+ # Footer
186
  st.markdown("---")
187
  st.markdown("Made with ❤️ using Streamlit")
188