Browse files
@@ -8,42 +8,175 @@ from tensorflow.keras.applications import EfficientNetB0
8 |
from tensorflow.keras.applications.efficientnet import preprocess_input
9 |
import joblib
10 |
import io
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
def load_model_and_scaler():
48 |
"""Load the trained model and scaler"""
49 |
@@ -60,34 +193,34 @@ def color_histogram(image, bins=16):
60 |
hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
61 |
hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
62 |
hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
63 |
64 |
hist_r = hist_r / (np.sum(hist_r) + 1e-7)
65 |
hist_g = hist_g / (np.sum(hist_g) + 1e-7)
66 |
hist_b = hist_b / (np.sum(hist_b) + 1e-7)
67 |
68 |
return np.concatenate([hist_r, hist_g, hist_b])
69 |
70 |
def color_moments(image):
    """Compute per-channel color moments (mean, std, skewness).

    Args:
        image: Array-like indexed as ``[row, col, channel]``. Only channels
            0, 1 and 2 are read. Values are divided by 255 before the
            moments are computed.

    Returns:
        1-D NumPy array of 9 values laid out as
        ``[mean, std, skewness]`` for channel 0, then channel 1, then
        channel 2. The reported ``std`` includes the ``1e-7`` stabilizer.
    """
    # NOTE(review): preprocess_image already divides the pixels by 255 before
    # feature extraction, so on that path the scaling is applied twice. Kept
    # unchanged because the saved scaler was presumably fitted on features
    # computed this way -- confirm before altering.
    img = np.asarray(image, dtype=np.float32) / 255.0
    moments = []

    for i in range(3):
        channel = img[:, :, i]
        mean = np.mean(channel)
        # The epsilon keeps std strictly positive, so the division below is
        # always safe and no separate zero check is required.
        std = np.std(channel) + 1e-7
        skewness = np.mean(((channel - mean) / std) ** 3)
        moments.extend([mean, std, skewness])

    return np.array(moments)
83 |
84 |
def dominant_color_descriptor(image, k=3):
85 |
"""Calculate dominant color descriptor"""
86 |
pixels = image.reshape(-1, 3).astype(np.float32)
87 |
88 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
89 |
90 |
91 |
92 |
_, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
93 |
unique, counts = np.unique(labels, return_counts=True)
@@ -100,16 +233,16 @@ def color_coherence_vector(image, k=3):
100 |
"""Calculate color coherence vector"""
101 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
102 |
gray = np.uint8(gray)
103 |
104 |
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
105 |
num_labels, labels = cv2.connectedComponents(binary)
106 |
107 |
ccv = []
108 |
for i in range(1, min(k+1, num_labels)):
109 |
region_mask = (labels == i)
110 |
total_pixels = np.sum(region_mask)
111 |
ccv.extend([total_pixels, total_pixels])
112 |
113 |
ccv.extend([0] * (2 * k - len(ccv)))
114 |
return np.array(ccv[:2*k])
115 |
@@ -119,13 +252,13 @@ def create_vit_feature_extractor():
119 |
input_shape = (256, 256, 3)
120 |
inputs = layers.Input(shape=input_shape)
121 |
x = layers.Lambda(preprocess_input)(inputs)
122 |
123 |
base_model = EfficientNetB0(
124 |
125 |
126 |
127 |
128 |
129 |
x = layers.GlobalAveragePooling2D()(base_model.output)
130 |
return models.Model(inputs=inputs, outputs=x)
131 |
@@ -136,21 +269,21 @@ def extract_features(image):
136 |
moment_features = color_moments(image)
137 |
dominant_features = dominant_color_descriptor(image)
138 |
ccv_features = color_coherence_vector(image)
139 |
140 |
traditional_features = np.concatenate([
141 |
142 |
143 |
144 |
145 |
146 |
147 |
# Deep features using ViT
148 |
feature_extractor = create_vit_feature_extractor()
149 |
vit_features = feature_extractor.predict(
150 |
np.expand_dims(image, axis=0),
151 |
152 |
153 |
154 |
# Combine all features
155 |
return np.concatenate([traditional_features, vit_features.flatten()])
156 |
@@ -159,100 +292,25 @@ def preprocess_image(image, scaler):
159 |
# Convert to RGB if needed
160 |
if image.mode != 'RGB':
161 |
image = image.convert('RGB')
162 |
163 |
# Convert to numpy array and resize
164 |
img_array = np.array(image)
165 |
img_array = cv2.resize(img_array, (256, 256))
166 |
img_array = img_array.astype('float32') / 255.0
167 |
168 |
# Extract all features
169 |
features = extract_features(img_array)
170 |
171 |
# Scale features using the provided scaler
172 |
scaled_features = scaler.transform(features.reshape(1, -1))
173 |
174 |
return scaled_features
175 |
176 |
177 |
"""Get top k predictions with their probabilities"""
178 |
top_indices = prediction.argsort()[0][-top_k:][::-1]
179 |
return [
180 |
(class_names[i], float(prediction[0][i]) * 100)
181 |
for i in top_indices
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
model, scaler = load_model_and_scaler()
190 |
if model is None or scaler is None:
191 |
st.error("Failed to load model or scaler. Please ensure both files exist.")
192 |
193 |
194 |
# Initialize session state
195 |
if 'predictions' not in st.session_state:
196 |
st.session_state.predictions = None
197 |
198 |
col1, col2 = st.columns(2)
199 |
200 |
with col1:
201 |
st.subheader("Upload Image")
202 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
203 |
204 |
if uploaded_file is not None:
205 |
206 |
image =
207 |
st.image(image, caption="Uploaded Image", use_column_width=True)
208 |
209 |
with st.spinner('Analyzing image...'):
210 |
processed_image = preprocess_image(image, scaler)
211 |
prediction = model.predict(processed_image, verbose=0)
212 |
213 |
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
214 |
st.session_state.predictions = get_top_predictions(prediction, class_names)
215 |
216 |
except Exception as e:
217 |
st.error(f"Error processing image: {str(e)}")
218 |
219 |
with col2:
220 |
st.subheader("Prediction Results")
221 |
if st.session_state.predictions:
222 |
# Display main prediction
223 |
top_class, top_confidence = st.session_state.predictions[0]
224 |
225 |
226 |
<div class='prediction-card'>
227 |
<h3>Primary Prediction: Grade {top_class}</h3>
228 |
<h3>Confidence: {top_confidence:.2f}%</h3>
229 |
230 |
231 |
232 |
233 |
234 |
# Display confidence bar
235 |
st.progress(top_confidence / 100)
236 |
237 |
# Display top 5 predictions
238 |
st.markdown("### Top 5 Predictions")
239 |
st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
240 |
241 |
for class_name, confidence in st.session_state.predictions:
242 |
cols = st.columns([2, 6, 2])
243 |
with cols[0]:
244 |
st.write(f"Grade {class_name}")
245 |
with cols[1]:
246 |
st.progress(confidence / 100)
247 |
with cols[2]:
248 |
249 |
250 |
st.markdown("</div>", unsafe_allow_html=True)
251 |
252 |
-"Upload an image to see the predictions")
253 |
254 |
255 |
st.markdown("Made with ❤️ using Streamlit")
256 |
257 |
if __name__ == "__main__":
258 |
8 |
from tensorflow.keras.applications.efficientnet import preprocess_input
9 |
import joblib
10 |
import io
11 |
import os
12 |
13 |
# Add Cloudinary import
14 |
import cloudinary
15 |
import cloudinary.uploader
16 |
from cloudinary.utils import cloudinary_url
17 |
18 |
# Cloudinary Configuration
19 |
20 |
cloud_name = os.getenv("CLOUD"),
21 |
api_key = os.getenv("API"),
22 |
api_secret = os.getenv("SECRET"),
23 |
24 |
25 |
26 |
def upload_to_cloudinary(file_path, label):
27 |
28 |
Upload file to Cloudinary with specified label as folder
29 |
30 |
31 |
# Upload to Cloudinary
32 |
upload_result = cloudinary.uploader.upload(
33 |
34 |
35 |
36 |
37 |
38 |
# Generate optimized URLs
39 |
optimize_url, _ = cloudinary_url(
40 |
41 |
42 |
43 |
44 |
45 |
auto_crop_url, _ = cloudinary_url(
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
return {
54 |
"upload_result": upload_result,
55 |
"optimize_url": optimize_url,
56 |
"auto_crop_url": auto_crop_url
57 |
58 |
59 |
except Exception as e:
60 |
return f"Error uploading to Cloudinary: {str(e)}"
61 |
62 |
def main():
63 |
st.title("🨨 Phân loại đá")
64 |
st.write("Tải lên hình ảnh của một viên đá để phân loại loại của nó.")
65 |
66 |
# Load model and scaler
67 |
model, scaler = load_model_and_scaler()
68 |
if model is None or scaler is None:
69 |
st.error("Không thể tải mô hình hoặc bộ chuẩn hóa. Vui lòng đảm bảo rằng cả hai tệp đều tồn tại.")
70 |
71 |
72 |
# Initialize session state
73 |
if 'predictions' not in st.session_state:
74 |
st.session_state.predictions = None
75 |
if 'uploaded_image' not in st.session_state:
76 |
st.session_state.uploaded_image = None
77 |
78 |
col1, col2 = st.columns(2)
79 |
80 |
with col1:
81 |
st.subheader("Tải lên Hình ảnh")
82 |
uploaded_file = st.file_uploader("Chọn hình ảnh...", type=["jpg", "jpeg", "png"])
83 |
84 |
if uploaded_file is not None:
85 |
86 |
image =
87 |
st.image(image, caption="Hình ảnh đã tải lên", use_column_width=True)
88 |
st.session_state.uploaded_image = image
89 |
90 |
with st.spinner('Đang phân tích hình ảnh...'):
91 |
processed_image = preprocess_image(image, scaler)
92 |
prediction = model.predict(processed_image, verbose=0)
93 |
94 |
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
95 |
st.session_state.predictions = get_top_predictions(prediction, class_names)
96 |
97 |
except Exception as e:
98 |
st.error(f"Lỗi khi xử lý hình ảnh: {str(e)}")
99 |
100 |
with col2:
101 |
st.subheader("Kết quả Dự đoán")
102 |
if st.session_state.predictions:
103 |
# Display main prediction
104 |
top_class, top_confidence = st.session_state.predictions[0]
105 |
106 |
107 |
<div class='prediction-card'>
108 |
<h3>Dự đoán chính: Màu {top_class}</h3>
109 |
<h3>Độ tin cậy: {top_confidence:.2f}%</h3>
110 |
111 |
112 |
113 |
114 |
115 |
# Display confidence bar
116 |
st.progress(top_confidence / 100)
117 |
118 |
# Display top 5 predictions
119 |
st.markdown("### 5 Dự đoán hàng đầu")
120 |
st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
121 |
122 |
for class_name, confidence in st.session_state.predictions:
123 |
124 |
f"**Màu {class_name}: Độ tin cậy {confidence:.2f}%**"
125 |
126 |
st.progress(confidence / 100)
127 |
128 |
st.markdown("</div>", unsafe_allow_html=True)
129 |
130 |
# User Confirmation Section
131 |
st.markdown("### Xác nhận độ chính xác của mô hình")
132 |
st.write("Giúp chúng tôi cải thiện mô hình bằng cách xác nhận độ chính xác của dự đoán.")
133 |
134 |
# Accuracy Radio Button
135 |
accuracy_option =
136 |
"Dự đoán có chính xác không?",
137 |
["Chọn", "Chính xác", "Không chính xác"],
138 |
139 |
140 |
141 |
if accuracy_option == "Không chính xác":
142 |
# Input for correct grade
143 |
correct_grade = st.selectbox(
144 |
"Chọn màu đá đúng:",
145 |
['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'],
146 |
147 |
placeholder="Chọn màu đúng"
148 |
149 |
150 |
# Upload button
151 |
if st.button("Tải lên Hình ảnh để sửa chữa"):
152 |
if correct_grade and st.session_state.uploaded_image:
153 |
# Save the image temporarily
154 |
temp_image_path = f"temp_image_{hash(}.png"
155 |
156 |
157 |
158 |
# Upload to Cloudinary
159 |
cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
160 |
161 |
if isinstance(cloudinary_result, dict):
162 |
st.success(f"Hình ảnh đã được tải lên thành công cho màu {correct_grade}")
163 |
st.write(f"URL công khai: {cloudinary_result['upload_result']['secure_url']}")
164 |
165 |
166 |
167 |
# Clean up temporary file
168 |
169 |
170 |
except Exception as e:
171 |
st.error(f"Tải lên thất bại: {str(e)}")
172 |
173 |
st.warning("Vui lòng chọn màu đúng trước khi tải lên.")
174 |
175 |
+"Tải lên hình ảnh để xem các dự đoán.")
176 |
177 |
178 |
st.markdown("Tạo bởi ❤️ với Streamlit")
179 |
180 |
def load_model_and_scaler():
181 |
"""Load the trained model and scaler"""
182 |
193 |
hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
194 |
hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
195 |
hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
196 |
197 |
hist_r = hist_r / (np.sum(hist_r) + 1e-7)
198 |
hist_g = hist_g / (np.sum(hist_g) + 1e-7)
199 |
hist_b = hist_b / (np.sum(hist_b) + 1e-7)
200 |
201 |
return np.concatenate([hist_r, hist_g, hist_b])
202 |
203 |
def color_moments(image):
    """Compute per-channel color moments (mean, std, skewness).

    Args:
        image: Array-like indexed as ``[row, col, channel]``. Only channels
            0, 1 and 2 are read. Values are divided by 255 before the
            moments are computed.

    Returns:
        1-D NumPy array of 9 values laid out as
        ``[mean, std, skewness]`` for channel 0, then channel 1, then
        channel 2. The reported ``std`` includes the ``1e-7`` stabilizer.
    """
    # NOTE(review): preprocess_image already divides the pixels by 255 before
    # feature extraction, so on that path the scaling is applied twice. Kept
    # unchanged because the saved scaler was presumably fitted on features
    # computed this way -- confirm before altering.
    img = np.asarray(image, dtype=np.float32) / 255.0
    moments = []

    for i in range(3):
        channel = img[:, :, i]
        mean = np.mean(channel)
        # The epsilon keeps std strictly positive, so the division below is
        # always safe and no separate zero check is required.
        std = np.std(channel) + 1e-7
        skewness = np.mean(((channel - mean) / std) ** 3)
        moments.extend([mean, std, skewness])

    return np.array(moments)
216 |
217 |
def dominant_color_descriptor(image, k=3):
218 |
"""Calculate dominant color descriptor"""
219 |
pixels = image.reshape(-1, 3).astype(np.float32)
220 |
221 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
222 |
223 |
224 |
225 |
_, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
226 |
unique, counts = np.unique(labels, return_counts=True)
233 |
"""Calculate color coherence vector"""
234 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
235 |
gray = np.uint8(gray)
236 |
237 |
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
238 |
num_labels, labels = cv2.connectedComponents(binary)
239 |
240 |
ccv = []
241 |
for i in range(1, min(k+1, num_labels)):
242 |
region_mask = (labels == i)
243 |
total_pixels = np.sum(region_mask)
244 |
ccv.extend([total_pixels, total_pixels])
245 |
246 |
ccv.extend([0] * (2 * k - len(ccv)))
247 |
return np.array(ccv[:2*k])
248 |
252 |
input_shape = (256, 256, 3)
253 |
inputs = layers.Input(shape=input_shape)
254 |
x = layers.Lambda(preprocess_input)(inputs)
255 |
256 |
base_model = EfficientNetB0(
257 |
258 |
259 |
260 |
261 |
262 |
x = layers.GlobalAveragePooling2D()(base_model.output)
263 |
return models.Model(inputs=inputs, outputs=x)
264 |
269 |
moment_features = color_moments(image)
270 |
dominant_features = dominant_color_descriptor(image)
271 |
ccv_features = color_coherence_vector(image)
272 |
273 |
traditional_features = np.concatenate([
274 |
275 |
276 |
277 |
278 |
279 |
280 |
# Deep features using ViT
281 |
feature_extractor = create_vit_feature_extractor()
282 |
vit_features = feature_extractor.predict(
283 |
np.expand_dims(image, axis=0),
284 |
285 |
286 |
287 |
# Combine all features
288 |
return np.concatenate([traditional_features, vit_features.flatten()])
289 |
292 |
# Convert to RGB if needed
293 |
if image.mode != 'RGB':
294 |
image = image.convert('RGB')
295 |
296 |
# Convert to numpy array and resize
297 |
img_array = np.array(image)
298 |
img_array = cv2.resize(img_array, (256, 256))
299 |
img_array = img_array.astype('float32') / 255.0
300 |
301 |
# Extract all features
302 |
features = extract_features(img_array)
303 |
304 |
# Scale features using the provided scaler
305 |
scaled_features = scaler.transform(features.reshape(1, -1))
306 |
307 |
return scaled_features
308 |
309 |
def get_top_predictions(prediction, class_names, top_k=5):
    """Return the ``top_k`` most probable classes with percentage confidences.

    Args:
        prediction: Model output; only its first row (``prediction[0]``) is
            used. A softmax is applied to that row.
        class_names: Sequence of labels indexed by class position.
        top_k: Maximum number of entries to return. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        List of ``(class_name, confidence_percent)`` tuples ordered from most
        to least probable. Confidences are builtin ``float`` values in the
        range 0-100.
    """
    # NOTE(review): if the model's last layer already applies softmax, this
    # second normalization flattens the confidences -- confirm the model
    # actually emits logits.
    scores = np.asarray(prediction[0], dtype=np.float64)
    # Subtracting the maximum before exponentiating avoids overflow without
    # changing the softmax result.
    exp_scores = np.exp(scores - np.max(scores))
    probabilities = exp_scores / np.sum(exp_scores)

    # Sort descending first, then truncate, so top_k == 0 yields an empty
    # list instead of the whole ranking (which a ``[-0:]`` slice would give).
    top_indices = np.argsort(probabilities)[::-1][:max(top_k, 0)]

    # Cast to builtin float: callers pass these values on to widgets such as
    # st.progress, which are understood to type-check for int/float and may
    # reject NumPy float32 scalars.
    return [(class_names[i], float(probabilities[i]) * 100) for i in top_indices]
314 |
315 |
if __name__ == "__main__":
316 |