SonFox2920 commited on
Commit
c2ddd50
·
verified ·
1 Parent(s): 2f28467

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +168 -611
  2. requirements.txt +1 -9
app.py CHANGED
@@ -2,632 +2,189 @@ import streamlit as st
2
  import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
- import pywt # Thư viện xử lý wavelet
6
  from PIL import Image
7
- from tensorflow.keras import layers, models
8
- from tensorflow.keras.applications import EfficientNetB0
9
- from tensorflow.keras.applications.efficientnet import preprocess_input
10
- import joblib
11
  import io
12
- import os
13
- import cv2
14
- import numpy as np
15
- from tensorflow import keras
16
- from tensorflow.keras import layers, models
17
- from sklearn.preprocessing import StandardScaler
18
- from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
19
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
20
- import matplotlib.pyplot as plt
21
- import random
22
- from keras.applications import ResNet50
23
- from tensorflow.keras.applications.resnet import preprocess_input
24
- from tensorflow.keras import layers, models
25
- from tensorflow.keras.applications.resnet import preprocess_input
26
- from tensorflow.keras.applications import EfficientNetB0
27
- from tensorflow.keras.applications.efficientnet import preprocess_input
28
- from tensorflow.keras.layers import Lambda # Đảm bảo nhập Lambda từ tensorflow.keras.layers
29
- from skimage.feature import graycomatrix, graycoprops
30
- from keras.applications import ResNet50
31
- from tensorflow.keras.applications.resnet import preprocess_input
32
 
33
- # Add Cloudinary import
34
- import cloudinary
35
- import cloudinary.uploader
36
- from cloudinary.utils import cloudinary_url
37
-
38
- # Cloudinary Configuration
39
- cloudinary.config(
40
- cloud_name = os.getenv("CLOUD"),
41
- api_key = os.getenv("API"),
42
- api_secret = os.getenv("SECRET"),
43
- secure=True
44
  )
45
 
46
- def upload_to_cloudinary(file_path, label):
47
- """
48
- Upload file to Cloudinary with specified label as folder
49
- """
50
- try:
51
- # Upload to Cloudinary
52
- upload_result = cloudinary.uploader.upload(
53
- file_path,
54
- folder=label,
55
- public_id=f"{label}_{os.path.basename(file_path)}"
56
- )
57
-
58
- # Generate optimized URLs
59
- optimize_url, _ = cloudinary_url(
60
- upload_result['public_id'],
61
- fetch_format="auto",
62
- quality="auto"
63
- )
64
-
65
- auto_crop_url, _ = cloudinary_url(
66
- upload_result['public_id'],
67
- width=500,
68
- height=500,
69
- crop="auto",
70
- gravity="auto"
71
- )
72
-
73
- return {
74
- "upload_result": upload_result,
75
- "optimize_url": optimize_url,
76
- "auto_crop_url": auto_crop_url
77
- }
78
-
79
- except Exception as e:
80
- return f"Error uploading to Cloudinary: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def main():
83
- st.title("Web App Phân loại đá")
84
- st.write("Tải lên hình ảnh của một viên đá để phân loại loại của nó.")
85
-
86
- # Load model and scaler
87
- model, scaler = load_model_and_scaler()
88
- if model is None or scaler is None:
89
- st.error("Không thể tải mô hình hoặc bộ chuẩn hóa. Vui lòng đảm bảo rằng cả hai tệp đều tồn tại.")
90
- return
91
-
92
- # Initialize session state
93
  if 'predictions' not in st.session_state:
94
  st.session_state.predictions = None
95
- if 'uploaded_image' not in st.session_state:
96
- st.session_state.uploaded_image = None
97
-
98
  col1, col2 = st.columns(2)
99
-
100
  with col1:
101
- st.subheader("Tải lên Hình ảnh")
102
- uploaded_file = st.file_uploader("Chọn hình ảnh...", type=["jpg", "jpeg", "png"])
103
-
104
  if uploaded_file is not None:
105
- try:
106
- image = Image.open(uploaded_file)
107
- st.image(image, caption="Hình ảnh đã tải lên", use_column_width=True)
108
- st.session_state.uploaded_image = image
109
-
110
- with st.spinner('Đang phân tích hình ảnh...'):
111
- processed_image = preprocess_image(image, scaler)
112
- prediction = model.predict(processed_image)
113
-
 
 
 
 
 
114
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
115
- st.session_state.predictions = get_top_predictions(prediction, class_names)
116
-
117
- except Exception as e:
118
- st.error(f"Lỗi khi xử lý hình ảnh: {str(e)}")
119
-
 
 
 
 
 
120
  with col2:
121
- st.subheader("Kết quả Dự đoán")
122
- if st.session_state.predictions:
123
- # Display main prediction
124
- top_class, top_confidence = st.session_state.predictions[0]
125
- st.markdown(
126
- f"""
127
- <div class='prediction-card'>
128
- <h3>Dự đoán chính: Màu {top_class}</h3>
129
- <h3>Độ tin cậy: {top_confidence:.2f}%</h3>
130
- </div>
131
- """,
132
- unsafe_allow_html=True
133
- )
134
-
135
- # Display confidence bar
136
- st.progress(float(top_confidence) / 100)
137
-
138
- # Display top 5 predictions
139
- st.markdown("### 5 Dự đoán hàng đầu")
140
- st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
141
-
142
- for class_name, confidence in st.session_state.predictions:
143
- st.markdown(
144
- f"**Màu {class_name}: Độ tin cậy {confidence:.2f}%**"
145
- )
146
- st.progress(float(confidence) / 100)
147
-
148
- st.markdown("</div>", unsafe_allow_html=True)
149
-
150
- # User Confirmation Section
151
- st.markdown("### Xác nhận độ chính xác của mô hình")
152
- st.write("Giúp chúng tôi cải thiện hình bằng cách xác nhận độ chính xác của dự đoán.")
153
-
154
- # Accuracy Radio Button
155
- accuracy_option = st.radio(
156
- "Dự đoán có chính xác không?",
157
- ["Chọn", "Chính xác", "Không chính xác"],
158
- index=0
159
- )
160
-
161
- if accuracy_option == "Không chính xác":
162
- # Input for correct grade
163
- correct_grade = st.selectbox(
164
- "Chọn màu đá đúng:",
165
- ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'],
166
- index=None,
167
- placeholder="Chọn màu đúng"
168
- )
169
-
170
- # Upload button
171
- if st.button("Tải lên Hình ảnh để sửa chữa"):
172
- if correct_grade and st.session_state.uploaded_image:
173
- try:
174
- # Đọc tệp từ `st.session_state.uploaded_image`
175
- uploaded_file = st.session_state.uploaded_image
176
-
177
- # Kiểm tra nếu uploaded_file đã là PIL Image
178
- if isinstance(uploaded_file, Image.Image):
179
- resized_image = uploaded_file.resize((512, 512))
180
- else:
181
- # Nếu là file-like object, mở bằng Pillow
182
- uploaded_image = Image.open(uploaded_file)
183
- resized_image = uploaded_image.resize((512, 512))
184
-
185
- # Lưu tệp ảnh resize tạm thời
186
- temp_image_path = f"temp_image_{hash(uploaded_file.name) if hasattr(uploaded_file, 'name') else 'unknown'}.png"
187
- resized_image.save(temp_image_path)
188
-
189
- # Tải ảnh lên Cloudinary
190
- cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
191
-
192
- if isinstance(cloudinary_result, dict):
193
- st.success(f"Hình ảnh đã được tải lên thành công cho màu {correct_grade}")
194
- st.write(f"URL công khai: {cloudinary_result['upload_result']['secure_url']}")
195
- else:
196
- st.error(cloudinary_result)
197
-
198
- # Xóa tệp tạm
199
- os.remove(temp_image_path)
200
-
201
- except Exception as e:
202
- st.error(f"Tải lên thất bại: {str(e)}")
203
- else:
204
- st.warning("Vui lòng chọn màu đúng trước khi tải lên.")
205
- else:
206
- st.info("Tải lên hình ảnh để xem các dự đoán.")
207
-
208
  st.markdown("---")
209
- st.markdown("Tạo bởi ❤️ với Streamlit")
210
-
211
- def load_model_and_scaler():
212
- """Load the trained model and scaler"""
213
- try:
214
- model = tf.keras.models.load_model('mlp_model.h5')
215
- # Tải scaler đã lưu
216
- scaler = joblib.load('scaler.pkl')
217
- return model, scaler
218
- except Exception as e:
219
- st.error(f"Error loading model or scaler: {str(e)}")
220
- return None, None
221
-
222
- def color_histogram(image, bins=16):
223
- """
224
- Tính histogram màu cho ảnh RGB
225
-
226
- Args:
227
- image (np.ndarray): Ảnh đầu vào
228
- bins (int): Số lượng bins của histogram
229
-
230
- Returns:
231
- np.ndarray: Histogram màu được chuẩn hóa
232
- """
233
- # Kiểm tra và chuyển đổi ảnh
234
- if image is None or image.size == 0:
235
- raise ValueError("Ảnh không hợp lệ")
236
-
237
- # Đảm bảo ảnh ở dạng uint8
238
- if image.dtype != np.uint8:
239
- image = (image * 255).astype(np.uint8)
240
-
241
- # Tính histogram cho từng kênh màu
242
- hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
243
- hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
244
- hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
245
-
246
- # Chuẩn hóa histogram
247
- hist_r = hist_r / np.sum(hist_r) if np.sum(hist_r) > 0 else hist_r
248
- hist_g = hist_g / np.sum(hist_g) if np.sum(hist_g) > 0 else hist_g
249
- hist_b = hist_b / np.sum(hist_b) if np.sum(hist_b) > 0 else hist_b
250
-
251
- return np.concatenate([hist_r, hist_g, hist_b])
252
-
253
- def color_moments(image):
254
- """
255
- Tính các moment màu cho ảnh
256
-
257
- Args:
258
- image (np.ndarray): Ảnh đầu vào
259
-
260
- Returns:
261
- np.ndarray: Các moment màu
262
- """
263
- # Kiểm tra và chuyển đổi ảnh
264
- if image is None or image.size == 0:
265
- raise ValueError("Ảnh không hợp lệ")
266
-
267
- # Đảm bảo ảnh ở dạng float và chuẩn hóa
268
- img = image.astype(np.float32) / 255.0 if image.max() > 1 else image.astype(np.float32)
269
-
270
- moments = []
271
- for i in range(3): # Cho mỗi kênh màu
272
- channel = img[:,:,i]
273
-
274
- # Tính các moment
275
- mean = np.mean(channel)
276
- std = np.std(channel)
277
- skewness = np.mean(((channel - mean) / (std + 1e-8)) ** 3)
278
-
279
- moments.extend([mean, std, skewness])
280
-
281
- return np.array(moments)
282
-
283
- def dominant_color_descriptor(image, k=3):
284
- """
285
- Xác định các màu chính thống trị trong ảnh
286
-
287
- Args:
288
- image (np.ndarray): Ảnh đầu vào
289
- k (int): Số lượng màu chủ đạo
290
-
291
- Returns:
292
- np.ndarray: Các màu chủ đạo và tỷ lệ
293
- """
294
- # Kiểm tra và chuyển đổi ảnh
295
- if image is None or image.size == 0:
296
- raise ValueError("Ảnh không hợp lệ")
297
-
298
- # Đảm bảo ảnh ở dạng uint8
299
- if image.dtype != np.uint8:
300
- image = (image * 255).astype(np.uint8)
301
-
302
- # Reshape ảnh thành mảng pixel
303
- pixels = image.reshape(-1, 3)
304
-
305
- # Các tham số cho K-means
306
- criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
307
- flags = cv2.KMEANS_RANDOM_CENTERS
308
-
309
- try:
310
- # Thực hiện phân cụm K-means
311
- _, labels, centers = cv2.kmeans(
312
- pixels.astype(np.float32), k, None, criteria, 10, flags
313
- )
314
-
315
- # Tính toán số lượng và tỷ lệ của từng cụm
316
- unique, counts = np.unique(labels, return_counts=True)
317
- percentages = counts / len(labels)
318
-
319
- # Kết hợp các màu và tỷ lệ
320
- dominant_colors = centers.flatten()
321
- color_percentages = percentages
322
-
323
- return np.concatenate([dominant_colors, color_percentages])
324
- except Exception:
325
- # Trả về mảng 0 nếu có lỗi
326
- return np.zeros(2 * k)
327
-
328
- def color_coherence_vector(image, k=3):
329
- """
330
- Tính vector liên kết màu
331
-
332
- Args:
333
- image (np.ndarray): Ảnh đầu vào
334
- k (int): Số lượng vùng
335
-
336
- Returns:
337
- np.ndarray: Vector liên kết màu
338
- """
339
- # Kiểm tra và chuyển đổi ảnh
340
- if image is None or image.size == 0:
341
- raise ValueError("Ảnh không hợp lệ")
342
-
343
- # Chuyển sang ảnh xám
344
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
345
-
346
- # Đảm bảo ảnh ở dạng uint8
347
- if gray.dtype != np.uint8:
348
- gray = np.uint8(gray)
349
-
350
- # Áp dụng Otsu's thresholding
351
- _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
352
-
353
- # Phân tích thành phần liên thông
354
- num_labels, labels = cv2.connectedComponents(binary)
355
-
356
- ccv = []
357
- for i in range(1, min(k+1, num_labels)):
358
- region_mask = (labels == i)
359
- total_pixels = np.sum(region_mask)
360
- coherent_pixels = total_pixels
361
-
362
- ccv.extend([coherent_pixels, total_pixels])
363
-
364
- # Đảm bảo độ dài vector
365
- while len(ccv) < 2 * k:
366
- ccv.append(0)
367
-
368
- return np.array(ccv)
369
-
370
- def edge_features(image, bins=16):
371
- """
372
- Trích xuất đặc trưng cạnh từ ảnh
373
-
374
- Args:
375
- image (np.ndarray): Ảnh đầu vào
376
- bins (int): Số lượng bins của histogram
377
-
378
- Returns:
379
- np.ndarray: Đặc trưng cạnh
380
- """
381
- # Kiểm tra và chuyển đổi ảnh
382
- if image is None or image.size == 0:
383
- raise ValueError("Ảnh không hợp lệ")
384
-
385
- # Chuyển sang ảnh xám
386
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
387
-
388
- # Đảm bảo ảnh ở dạng uint8
389
- if gray.dtype != np.uint8:
390
- gray = np.uint8(gray)
391
-
392
- # Tính Sobel edges
393
- sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
394
- sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
395
- sobel_mag = np.sqrt(sobel_x**2 + sobel_y**2)
396
-
397
- # Chuẩn hóa độ lớn Sobel
398
- sobel_mag = np.uint8(255 * sobel_mag / np.max(sobel_mag))
399
-
400
- # Tính histogram của Sobel magnitude
401
- sobel_hist = cv2.calcHist([sobel_mag], [0], None, [bins], [0, 256]).flatten()
402
- sobel_hist = sobel_hist / np.sum(sobel_hist) if np.sum(sobel_hist) > 0 else sobel_hist
403
-
404
- # Tính mật độ cạnh bằng Canny
405
- canny_edges = cv2.Canny(gray, 100, 200)
406
- edge_density = np.sum(canny_edges) / (gray.shape[0] * gray.shape[1])
407
-
408
- return np.concatenate([sobel_hist, [edge_density]])
409
-
410
-
411
-
412
- def histogram_in_color_space(image, color_space='HSV', bins=16):
413
- """
414
- Tính histogram của ảnh trong một không gian màu mới.
415
- """
416
- if color_space == 'HSV':
417
- converted = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
418
- elif color_space == 'LAB':
419
- converted = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
420
- else:
421
- raise ValueError("Unsupported color space")
422
-
423
- histograms = []
424
- for i in range(3): # 3 kênh màu
425
- hist = cv2.calcHist([converted], [i], None, [bins], [0, 256]).flatten()
426
- hist = hist / np.sum(hist)
427
- histograms.append(hist)
428
-
429
- return np.concatenate(histograms)
430
-
431
- def glcm_features(image, distances=[1, 2, 3], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], levels=256):
432
- """
433
- Tính các đặc trưng GLCM của ảnh grayscale.
434
- """
435
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
436
-
437
- # Đảm bảo ảnh ở dạng uint8
438
- if gray.dtype != np.uint8:
439
- gray = (gray * 255).astype(np.uint8)
440
-
441
- glcm = graycomatrix(gray, distances=distances, angles=angles, levels=levels, symmetric=True, normed=True)
442
-
443
- features = []
444
- # Các thuộc tính phổ biến: contrast, homogeneity, energy, correlation
445
- for prop in ['contrast', 'homogeneity', 'energy', 'correlation']:
446
- features.extend(graycoprops(glcm, prop).flatten())
447
-
448
- return np.array(features)
449
-
450
-
451
- def gabor_features(image, kernels=None):
452
- """
453
- Tính các đặc trưng từ bộ lọc Gabor.
454
- """
455
- if kernels is None:
456
- kernels = []
457
- for theta in np.arange(0, np.pi, np.pi / 4): # Các góc từ 0 đến 180 độ
458
- for sigma in [1, 3]: # Các giá trị sigma
459
- for frequency in [0.1, 0.5]: # Các tần số
460
- kernel = cv2.getGaborKernel((9, 9), sigma, theta, 1/frequency, gamma=0.5, ktype=cv2.CV_32F)
461
- kernels.append(kernel)
462
-
463
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
464
- features = []
465
- for kernel in kernels:
466
- filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
467
- features.append(filtered.mean())
468
- features.append(filtered.var())
469
-
470
- return np.array(features)
471
-
472
- def wavelet_features(image, wavelet='db1', level=3):
473
- """
474
- Trích xuất các hệ số wavelet từ ảnh grayscale.
475
- """
476
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
477
- coeffs = pywt.wavedec2(gray, wavelet, level=level)
478
- features = []
479
- for coeff in coeffs:
480
- if isinstance(coeff, tuple): # Chi tiết (LH, HL, HH)
481
- for subband in coeff:
482
- features.append(subband.mean())
483
- features.append(subband.var())
484
- else: # Xấp xỉ (LL)
485
- features.append(coeff.mean())
486
- features.append(coeff.var())
487
-
488
- return np.array(features)
489
-
490
- def fractal_dimension(image):
491
- """
492
- Tính Fractal Dimension của ảnh.
493
- """
494
- # Chuyển đổi ảnh sang grayscale
495
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
496
-
497
- # Đảm bảo ảnh ở dạng uint8
498
- if gray.dtype != np.uint8:
499
- gray = (gray * 255).astype(np.uint8)
500
-
501
- # Áp dụng Canny để tìm cạnh
502
- edges = cv2.Canny(gray, 100, 200)
503
-
504
- # Tính fractal dimension dựa trên phương pháp box-counting
505
- sizes = []
506
- counts = []
507
- for size in range(2, 65, 2): # Kích thước hộp từ 2 đến 64
508
- region_size = (edges.shape[0] // size, edges.shape[1] // size)
509
- count = np.sum(cv2.resize(edges, region_size, interpolation=cv2.INTER_AREA) > 0)
510
- sizes.append(size)
511
- counts.append(count)
512
-
513
- # Tính log-log slope
514
- log_sizes = np.log(sizes)
515
- log_counts = np.log(counts)
516
- slope, _ = np.polyfit(log_sizes, log_counts, 1)
517
-
518
- # Trả về giá trị fractal dimension
519
- return np.array([slope])
520
-
521
-
522
- def extract_features(image):
523
- """
524
- Extract multiple features from an image, including edge-based features.
525
- """
526
- color_hist = color_histogram(image)
527
- color_mom = color_moments(image)
528
- dom_color = dominant_color_descriptor(image)
529
- ccv = color_coherence_vector(image)
530
- edges = edge_features(image)
531
-
532
- # Các đặc trưng từ phương pháp mới
533
- hsv_hist = histogram_in_color_space(image, color_space='HSV')
534
- # lab_hist = histogram_in_color_space(image, color_space='LAB')
535
- glcm = glcm_features(image)
536
- gabor = gabor_features(image)
537
- wavelet = wavelet_features(image)
538
- # fractal = fractal_dimension(image)
539
-
540
- # Kết hợp tất cả thành một vector đặc trưng
541
- return np.concatenate([
542
- color_hist,
543
- color_mom,
544
- dom_color,
545
- ccv,
546
- edges,
547
- hsv_hist,
548
- # lab_hist,
549
- glcm,
550
- gabor,
551
- wavelet,
552
- # fractal
553
- ])
554
-
555
-
556
- def create_resnet50_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
557
- # Xây dựng mô hình ResNet112 đã huấn luyện sẵn từ TensorFlow
558
- inputs = layers.Input(shape=input_shape)
559
-
560
- # Thêm lớp Lambda để tiền xử lý ảnh
561
- x = Lambda(preprocess_input, output_shape=input_shape)(inputs) # Xử lý ảnh đầu vào
562
-
563
- # Sử dụng mô hình ResNet112 đã được huấn luyện sẵn
564
- resnet50_model = ResNet50(include_top=False, weights='imagenet', input_tensor=x)
565
-
566
- # Trích xuất đặc trưng từ mô hình ResNet112
567
- x = layers.GlobalAveragePooling2D()(resnet50_model.output)
568
-
569
- if num_classes:
570
- x = layers.Dense(num_classes, activation='softmax')(x) # Thêm lớp phân loại (nếu có)
571
-
572
- return models.Model(inputs=inputs, outputs=x)
573
-
574
- def extract_features(image):
575
- """
576
- Extract multiple features from an image, including edge-based features.
577
- """
578
- color_hist = color_histogram(image)
579
- color_mom = color_moments(image)
580
- dom_color = dominant_color_descriptor(image)
581
- ccv = color_coherence_vector(image)
582
- edges = edge_features(image)
583
-
584
- # Các đặc trưng từ phương pháp mới
585
- hsv_hist = histogram_in_color_space(image, color_space='HSV')
586
- # lab_hist = histogram_in_color_space(image, color_space='LAB')
587
- glcm = glcm_features(image)
588
- gabor = gabor_features(image)
589
- wavelet = wavelet_features(image)
590
- # fractal = fractal_dimension(image)
591
-
592
- # Kết hợp tất cả thành một vector đặc trưng
593
- return np.concatenate([
594
- color_hist,
595
- color_mom,
596
- dom_color,
597
- ccv,
598
- edges,
599
- hsv_hist,
600
- # lab_hist,
601
- glcm,
602
- gabor,
603
- wavelet,
604
- # fractal
605
- ])
606
-
607
- def preprocess_image(image, scaler):
608
- image=np.array(image)
609
- img_size=(256, 256)
610
- img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
611
- img = cv2.resize(img, img_size)
612
- img_array = img.astype('float32') / 255.0
613
-
614
- features1 = np.array(extract_features(img_array))
615
- resnet_extractor = create_resnet50_feature_extractor()
616
- features2 = resnet_extractor.predict(np.expand_dims(img_array, axis=0))
617
- print(f"Shape of features1: {features1.shape}")
618
- print(f"Shape of features2: {features2.shape}")
619
- features = np.concatenate([np.expand_dims(features1, axis=0), features2], axis=1) # Concatenate along axis 0
620
-
621
- # Scale features using the provided scaler
622
- scaled_features = scaler.transform(features) # Reshape for scaling
623
-
624
- return scaled_features
625
-
626
- def get_top_predictions(prediction, class_names):
627
- # Extract the top 5 predictions with confidence values
628
- probabilities = tf.nn.softmax(prediction[0]).numpy()
629
- top_indices = np.argsort(probabilities)[-5:][::-1]
630
- return [(class_names[i], probabilities[i] * 100) for i in top_indices]
631
 
632
  if __name__ == "__main__":
633
  main()
 
2
  import tensorflow as tf
3
  import numpy as np
4
  import cv2
 
5
  from PIL import Image
 
 
 
 
6
  import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Set page config
9
+ st.set_page_config(
10
+ page_title="Stone Classification",
11
+ page_icon="🪨",
12
+ layout="wide"
 
 
 
 
 
 
13
  )
14
 
15
+ # Custom CSS to improve the appearance
16
+ st.markdown("""
17
+ <style>
18
+ .main {
19
+ padding: 2rem;
20
+ }
21
+ .stButton>button {
22
+ width: 100%;
23
+ margin-top: 1rem;
24
+ }
25
+ .upload-text {
26
+ text-align: center;
27
+ padding: 2rem;
28
+ }
29
+ .prediction-card {
30
+ padding: 2rem;
31
+ border-radius: 0.5rem;
32
+ background-color: #f0f2f6;
33
+ margin: 1rem 0;
34
+ }
35
+ .top-predictions {
36
+ margin-top: 2rem;
37
+ padding: 1rem;
38
+ background-color: white;
39
+ border-radius: 0.5rem;
40
+ box-shadow: 0 1px 3px rgba(0,0,0,0.12);
41
+ }
42
+ .prediction-bar {
43
+ display: flex;
44
+ align-items: center;
45
+ margin: 0.5rem 0;
46
+ }
47
+ .prediction-label {
48
+ width: 100px;
49
+ font-weight: 500;
50
+ }
51
+ </style>
52
+ """, unsafe_allow_html=True)
53
+
54
+ @st.cache_resource
55
+ def load_model():
56
+ """Load the trained model"""
57
+ return tf.keras.models.load_model('custom_model.h5')
58
+
59
+ def preprocess_image(image):
60
+ """Preprocess the uploaded image"""
61
+ # # Convert to RGB if needed
62
+ # if image.mode != 'RGB':
63
+ # image = image.convert('RGB')
64
+
65
+ # Convert to numpy array
66
+ img_array = np.array(image)
67
+
68
+ # # Convert to RGB if needed
69
+ # if len(img_array.shape) == 2: # Grayscale
70
+ # img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
71
+ # elif img_array.shape[2] == 4: # RGBA
72
+ # img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
73
+
74
+ # # Preprocess image similar to training
75
+ # img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
76
+ # img_hsv[:, :, 2] = cv2.equalizeHist(img_hsv[:, :, 2])
77
+ # img_array = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
78
+
79
+ # # Adjust brightness
80
+ # target_brightness = 150
81
+ # current_brightness = np.mean(img_array)
82
+ # alpha = target_brightness / (current_brightness + 1e-5)
83
+ # img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)
84
+
85
+ # # Apply Gaussian blur
86
+ # img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
87
+
88
+ # Resize
89
+ img_array = cv2.resize(img_array, (256, 256))
90
+
91
+ # Normalize
92
+ img_array = img_array.astype('float32') / 255.0
93
+
94
+ return img_array
95
+
96
+ def get_top_predictions(prediction, class_names, top_k=5):
97
+ """Get top k predictions with their probabilities"""
98
+ # Get indices of top k predictions
99
+ top_indices = prediction.argsort()[0][-top_k:][::-1]
100
+
101
+ # Get corresponding class names and probabilities
102
+ top_predictions = [
103
+ (class_names[i], float(prediction[0][i]) * 100)
104
+ for i in top_indices
105
+ ]
106
+
107
+ return top_predictions
108
 
109
  def main():
110
+ # Title
111
+ st.title("🪨 Stone Classification")
112
+ st.write("Upload an image of a stone to classify its type")
113
+
114
+ # Initialize session state for prediction if not exists
 
 
 
 
 
115
  if 'predictions' not in st.session_state:
116
  st.session_state.predictions = None
117
+
118
+ # Create two columns
 
119
  col1, col2 = st.columns(2)
120
+
121
  with col1:
122
+ st.subheader("Upload Image")
123
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
124
+
125
  if uploaded_file is not None:
126
+ # Display uploaded image
127
+ image = Image.open(uploaded_file)
128
+ st.image(image, caption="Uploaded Image", use_column_width=True)
129
+
130
+ with st.spinner('Analyzing image...'):
131
+ try:
132
+ # Load model
133
+ model = load_model()
134
+
135
+ # Preprocess image
136
+ processed_image = preprocess_image(image)
137
+
138
+ # Make prediction
139
+ prediction = model.predict(np.expand_dims(processed_image, axis=0))
140
  class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
141
+
142
+ # Get top 5 predictions
143
+ top_predictions = get_top_predictions(prediction, class_names)
144
+
145
+ # Store in session state
146
+ st.session_state.predictions = top_predictions
147
+
148
+ except Exception as e:
149
+ st.error(f"Error during prediction: {str(e)}")
150
+
151
  with col2:
152
+ st.subheader("Prediction Results")
153
+ if st.session_state.predictions is not None:
154
+ # Create a card-like container for results
155
+ results_container = st.container()
156
+ with results_container:
157
+ # Display main prediction
158
+ st.markdown("<div class='prediction-card'>", unsafe_allow_html=True)
159
+ top_class, top_confidence = st.session_state.predictions[0]
160
+ st.markdown(f"### Primary Prediction: Grade {top_class}")
161
+ st.markdown(f"### Confidence: {top_confidence:.2f}%")
162
+ st.markdown("</div>", unsafe_allow_html=True)
163
+
164
+ # Display confidence bar for top prediction
165
+ st.progress(top_confidence / 100)
166
+
167
+ # Display top 5 predictions
168
+ st.markdown("### Top 5 Predictions")
169
+ st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
170
+
171
+ # Create a Streamlit container for the predictions
172
+ for class_name, confidence in st.session_state.predictions:
173
+ col_label, col_bar, col_value = st.columns([2, 6, 2])
174
+ with col_label:
175
+ st.write(f"Grade {class_name}")
176
+ with col_bar:
177
+ st.progress(confidence / 100)
178
+ with col_value:
179
+ st.write(f"{confidence:.2f}%")
180
+
181
+ st.markdown("</div>", unsafe_allow_html=True)
182
+ else:
183
+ st.info("Upload an image and click 'Predict' to see the results")
184
+
185
+ # Footer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  st.markdown("---")
187
+ st.markdown("Made with ❤️ using Streamlit")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  if __name__ == "__main__":
190
  main()
requirements.txt CHANGED
@@ -1,12 +1,4 @@
1
  streamlit
2
  tensorflow
3
  opencv-python
4
- pillow
5
- scikit-learn
6
- matplotlib
7
- transformers
8
- torch
9
- torchvision
10
- scikit-image
11
- PyWavelets
12
- cloudinary
 
1
  streamlit
2
  tensorflow
3
  opencv-python
4
+ pillow