SonFox2920 commited on
Commit
699b5d2
·
verified ·
1 Parent(s): d4824f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +458 -19
app.py CHANGED
@@ -1,16 +1,35 @@
1
- import streamlit as st
2
- import tensorflow as tf
3
- import numpy as np
 
 
4
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
- import io
7
- import torch
8
  import cloudinary
9
  import cloudinary.uploader
10
  from cloudinary.utils import cloudinary_url
11
- import os
12
- import random
13
- import string
14
  # Cloudinary Configuration
15
  cloudinary.config(
16
  cloud_name = os.getenv("CLOUD"),
@@ -93,10 +112,396 @@ def resize_to_square(image):
93
  # Copy the image into center of result image
94
  new_img[y_center:y_center+image.shape[0],
95
  x_center:x_center+image.shape[1]] = image
96
-
97
  return new_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  @st.cache_resource
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def load_models():
101
  """Load both object detection and classification models"""
102
  # Load object detection model
@@ -106,10 +511,28 @@ def load_models():
106
  object_detection_model.eval()
107
 
108
  # Load classification model
109
- classification_model = tf.keras.models.load_model('custom_model.h5')
110
 
111
  return object_detection_model, classification_model, device
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def perform_object_detection(image, model, device):
114
  original_size = image.size
115
  target_size = (256, 256)
@@ -164,10 +587,28 @@ def perform_object_detection(image, model, device):
164
 
165
  def preprocess_image(image):
166
  """Preprocess the image for classification"""
 
167
  img_array = np.array(image)
168
  img_array = cv2.resize(img_array, (256, 256))
169
- img_array = img_array.astype('float32') / 255.0
170
- return img_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def get_top_predictions(prediction, class_names, top_k=5):
173
  """Get top k predictions with their probabilities"""
@@ -218,13 +659,11 @@ def main():
218
  all_predictions = []
219
  all_image=[]
220
  for idx, cropped_image in cropped_images:
221
- processed_image = preprocess_image(cropped_image)
222
- prediction = classification_model.predict(
223
- np.expand_dims(processed_image, axis=0)
224
- )
225
- top_predictions = get_top_predictions(prediction, class_names)
226
- all_predictions.append([idx,top_predictions])
227
- all_image.append(cropped_image)
228
  # Store in session state
229
  st.session_state.predictions = all_predictions
230
  st.session_state.image = all_image
 
1
+ import os
2
+ import random
3
+ import string
4
+ import io
5
+ import joblib
6
  import cv2
7
+ import numpy as np
8
+ import tensorflow as tf
9
+ from tensorflow import keras
10
+ from tensorflow.keras import layers, models
11
+ from tensorflow.keras.applications import ResNet50, EfficientNetB0
12
+ from tensorflow.keras.applications.resnet import preprocess_input
13
+ from tensorflow.keras.applications.efficientnet import preprocess_input
14
+ from tensorflow.keras.layers import Lambda # Đảm bảo nhập Lambda từ tensorflow.keras.layers
15
+ from keras.applications import ResNet50
16
+ from sklearn.preprocessing import StandardScaler
17
+ from sklearn.metrics import (
18
+ confusion_matrix,
19
+ ConfusionMatrixDisplay,
20
+ accuracy_score,
21
+ precision_score,
22
+ recall_score,
23
+ f1_score
24
+ )
25
+ import matplotlib.pyplot as plt
26
+ from skimage.feature import graycomatrix, graycoprops
27
  from PIL import Image
28
+ import streamlit as st
 
29
  import cloudinary
30
  import cloudinary.uploader
31
  from cloudinary.utils import cloudinary_url
32
+ import torch
 
 
33
  # Cloudinary Configuration
34
  cloudinary.config(
35
  cloud_name = os.getenv("CLOUD"),
 
112
  # Copy the image into center of result image
113
  new_img[y_center:y_center+image.shape[0],
114
  x_center:x_center+image.shape[1]] = image
 
115
  return new_img
116
+ def color_histogram(image, bins=16):
117
+ """
118
+ Tính histogram màu cho ảnh RGB
119
+
120
+ Args:
121
+ image (np.ndarray): Ảnh đầu vào
122
+ bins (int): Số lượng bins của histogram
123
+
124
+ Returns:
125
+ np.ndarray: Histogram màu được chuẩn hóa
126
+ """
127
+ # Kiểm tra và chuyển đổi ảnh
128
+ if image is None or image.size == 0:
129
+ raise ValueError("Ảnh không hợp lệ")
130
+
131
+ # Đảm bảo ảnh ở dạng uint8
132
+ if image.dtype != np.uint8:
133
+ image = (image * 255).astype(np.uint8)
134
+
135
+ # Tính histogram cho từng kênh màu
136
+ hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
137
+ hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
138
+ hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
139
+
140
+ # Chuẩn hóa histogram
141
+ hist_r = hist_r / np.sum(hist_r) if np.sum(hist_r) > 0 else hist_r
142
+ hist_g = hist_g / np.sum(hist_g) if np.sum(hist_g) > 0 else hist_g
143
+ hist_b = hist_b / np.sum(hist_b) if np.sum(hist_b) > 0 else hist_b
144
+
145
+ return np.concatenate([hist_r, hist_g, hist_b])
146
+
147
+ def color_moments(image):
148
+ """
149
+ Tính các moment màu cho ảnh
150
+
151
+ Args:
152
+ image (np.ndarray): Ảnh đầu vào
153
+
154
+ Returns:
155
+ np.ndarray: Các moment màu
156
+ """
157
+ # Kiểm tra và chuyển đổi ảnh
158
+ if image is None or image.size == 0:
159
+ raise ValueError("Ảnh không hợp lệ")
160
+
161
+ # Đảm bảo ảnh ở dạng float và chuẩn hóa
162
+ img = image.astype(np.float32) / 255.0 if image.max() > 1 else image.astype(np.float32)
163
+
164
+ moments = []
165
+ for i in range(3): # Cho mỗi kênh màu
166
+ channel = img[:,:,i]
167
+
168
+ # Tính các moment
169
+ mean = np.mean(channel)
170
+ std = np.std(channel)
171
+ skewness = np.mean(((channel - mean) / (std + 1e-8)) ** 3)
172
+
173
+ moments.extend([mean, std, skewness])
174
+
175
+ return np.array(moments)
176
+
177
+ def dominant_color_descriptor(image, k=3):
178
+ """
179
+ Xác định các màu chính thống trị trong ảnh
180
+
181
+ Args:
182
+ image (np.ndarray): Ảnh đầu vào
183
+ k (int): Số lượng màu chủ đạo
184
+
185
+ Returns:
186
+ np.ndarray: Các màu chủ đạo và tỷ lệ
187
+ """
188
+ # Kiểm tra và chuyển đổi ảnh
189
+ if image is None or image.size == 0:
190
+ raise ValueError("Ảnh không hợp lệ")
191
+
192
+ # Đảm bảo ảnh ở dạng uint8
193
+ if image.dtype != np.uint8:
194
+ image = (image * 255).astype(np.uint8)
195
+
196
+ # Reshape ảnh thành mảng pixel
197
+ pixels = image.reshape(-1, 3)
198
+
199
+ # Các tham số cho K-means
200
+ criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
201
+ flags = cv2.KMEANS_RANDOM_CENTERS
202
+
203
+ try:
204
+ # Thực hiện phân cụm K-means
205
+ _, labels, centers = cv2.kmeans(
206
+ pixels.astype(np.float32), k, None, criteria, 10, flags
207
+ )
208
+
209
+ # Tính toán số lượng và tỷ lệ của từng cụm
210
+ unique, counts = np.unique(labels, return_counts=True)
211
+ percentages = counts / len(labels)
212
+
213
+ # Kết hợp các màu và tỷ lệ
214
+ dominant_colors = centers.flatten()
215
+ color_percentages = percentages
216
+
217
+ return np.concatenate([dominant_colors, color_percentages])
218
+ except Exception:
219
+ # Trả về mảng 0 nếu có lỗi
220
+ return np.zeros(2 * k)
221
+
222
+ def color_coherence_vector(image, k=3):
223
+ """
224
+ Tính vector liên kết màu
225
+
226
+ Args:
227
+ image (np.ndarray): Ảnh đầu vào
228
+ k (int): Số lượng vùng
229
+
230
+ Returns:
231
+ np.ndarray: Vector liên kết màu
232
+ """
233
+ # Kiểm tra và chuyển đổi ảnh
234
+ if image is None or image.size == 0:
235
+ raise ValueError("Ảnh không hợp lệ")
236
+
237
+ # Chuyển sang ảnh xám
238
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
239
+
240
+ # Đảm bảo ảnh ở dạng uint8
241
+ if gray.dtype != np.uint8:
242
+ gray = np.uint8(gray)
243
+
244
+ # Áp dụng Otsu's thresholding
245
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
246
+
247
+ # Phân tích thành phần liên thông
248
+ num_labels, labels = cv2.connectedComponents(binary)
249
+
250
+ ccv = []
251
+ for i in range(1, min(k+1, num_labels)):
252
+ region_mask = (labels == i)
253
+ total_pixels = np.sum(region_mask)
254
+ coherent_pixels = total_pixels
255
+
256
+ ccv.extend([coherent_pixels, total_pixels])
257
+
258
+ # Đảm bảo độ dài vector
259
+ while len(ccv) < 2 * k:
260
+ ccv.append(0)
261
+
262
+ return np.array(ccv)
263
+
264
+ def edge_features(image, bins=16):
265
+ """
266
+ Trích xuất đặc trưng cạnh từ ảnh
267
+
268
+ Args:
269
+ image (np.ndarray): Ảnh đầu vào
270
+ bins (int): Số lượng bins của histogram
271
+
272
+ Returns:
273
+ np.ndarray: Đặc trưng cạnh
274
+ """
275
+ # Kiểm tra và chuyển đổi ảnh
276
+ if image is None or image.size == 0:
277
+ raise ValueError("Ảnh không hợp lệ")
278
+
279
+ # Chuyển sang ảnh xám
280
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
281
+
282
+ # Đảm bảo ảnh ở dạng uint8
283
+ if gray.dtype != np.uint8:
284
+ gray = np.uint8(gray)
285
+
286
+ # Tính Sobel edges
287
+ sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
288
+ sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
289
+ sobel_mag = np.sqrt(sobel_x**2 + sobel_y**2)
290
+
291
+ # Chuẩn hóa độ lớn Sobel
292
+ sobel_mag = np.uint8(255 * sobel_mag / np.max(sobel_mag))
293
+
294
+ # Tính histogram của Sobel magnitude
295
+ sobel_hist = cv2.calcHist([sobel_mag], [0], None, [bins], [0, 256]).flatten()
296
+ sobel_hist = sobel_hist / np.sum(sobel_hist) if np.sum(sobel_hist) > 0 else sobel_hist
297
+
298
+ # Tính mật độ cạnh bằng Canny
299
+ canny_edges = cv2.Canny(gray, 100, 200)
300
+ edge_density = np.sum(canny_edges) / (gray.shape[0] * gray.shape[1])
301
+
302
+ return np.concatenate([sobel_hist, [edge_density]])
303
+
304
+
305
+ import pywt # Thư viện xử lý wavelet
306
 
307
+ def histogram_in_color_space(image, color_space='HSV', bins=16):
308
+ """
309
+ Tính histogram của ảnh trong một không gian màu mới.
310
+ """
311
+ if color_space == 'HSV':
312
+ converted = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
313
+ elif color_space == 'LAB':
314
+ converted = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
315
+ else:
316
+ raise ValueError("Unsupported color space")
317
+
318
+ histograms = []
319
+ for i in range(3): # 3 kênh màu
320
+ hist = cv2.calcHist([converted], [i], None, [bins], [0, 256]).flatten()
321
+ hist = hist / np.sum(hist)
322
+ histograms.append(hist)
323
+
324
+ return np.concatenate(histograms)
325
+
326
+ def glcm_features(image, distances=[1, 2, 3], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], levels=256):
327
+ """
328
+ Tính các đặc trưng GLCM của ảnh grayscale.
329
+ """
330
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
331
+
332
+ # Đảm bảo ảnh ở dạng uint8
333
+ if gray.dtype != np.uint8:
334
+ gray = (gray * 255).astype(np.uint8)
335
+
336
+ glcm = graycomatrix(gray, distances=distances, angles=angles, levels=levels, symmetric=True, normed=True)
337
+
338
+ features = []
339
+ # Các thuộc tính phổ biến: contrast, homogeneity, energy, correlation
340
+ for prop in ['contrast', 'homogeneity', 'energy', 'correlation']:
341
+ features.extend(graycoprops(glcm, prop).flatten())
342
+
343
+ return np.array(features)
344
+
345
+ def gabor_features(image, kernels=None):
346
+ """
347
+ Tính các đặc trưng từ bộ lọc Gabor.
348
+ """
349
+ if kernels is None:
350
+ kernels = []
351
+ for theta in np.arange(0, np.pi, np.pi / 4): # Các góc từ 0 đến 180 độ
352
+ for sigma in [1, 3]: # Các giá trị sigma
353
+ for frequency in [0.1, 0.5]: # Các tần số
354
+ kernel = cv2.getGaborKernel((9, 9), sigma, theta, 1/frequency, gamma=0.5, ktype=cv2.CV_32F)
355
+ kernels.append(kernel)
356
+
357
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
358
+ features = []
359
+ for kernel in kernels:
360
+ filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
361
+ features.append(filtered.mean())
362
+ features.append(filtered.var())
363
+
364
+ return np.array(features)
365
+
366
+ def wavelet_features(image, wavelet='db1', level=3):
367
+ """
368
+ Trích xuất các hệ số wavelet từ ảnh grayscale.
369
+ """
370
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
371
+ coeffs = pywt.wavedec2(gray, wavelet, level=level)
372
+ features = []
373
+ for coeff in coeffs:
374
+ if isinstance(coeff, tuple): # Chi tiết (LH, HL, HH)
375
+ for subband in coeff:
376
+ features.append(subband.mean())
377
+ features.append(subband.var())
378
+ else: # Xấp xỉ (LL)
379
+ features.append(coeff.mean())
380
+ features.append(coeff.var())
381
+
382
+ return np.array(features)
383
+
384
+ from skimage.feature import local_binary_pattern
385
+ from skimage.color import rgb2gray
386
+ from skimage.measure import shannon_entropy
387
+ from skimage.feature import hog
388
+
389
+ def illumination_features(image):
390
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
391
+ mean_brightness = np.mean(gray)
392
+ contrast = np.std(gray)
393
+ return np.array([mean_brightness, contrast])
394
+
395
+ def saturation_index(image):
396
+ hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
397
+ s_channel = hsv[:, :, 1]
398
+ mean_saturation = np.mean(s_channel)
399
+ std_saturation = np.std(s_channel)
400
+ return np.array([mean_saturation, std_saturation])
401
+
402
+ def local_binary_pattern_features(image, num_points=24, radius=3):
403
+ gray = rgb2gray(image)
404
+ lbp = local_binary_pattern(gray, num_points, radius, method="uniform")
405
+ hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, num_points + 3), range=(0, num_points + 2))
406
+ hist = hist / np.sum(hist)
407
+ return hist
408
+
409
+ def fourier_transform_features(image):
410
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
411
+ f_transform = np.fft.fft2(gray)
412
+ f_shift = np.fft.fftshift(f_transform)
413
+ magnitude_spectrum = 20 * np.log(np.abs(f_shift) + 1)
414
+ mean_frequency = np.mean(magnitude_spectrum)
415
+ std_frequency = np.std(magnitude_spectrum)
416
+ return np.array([mean_frequency, std_frequency])
417
+
418
+ def fractal_dimension(image):
419
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
420
+ size = gray.shape[0] * gray.shape[1]
421
+ edges = cv2.Canny(gray, 100, 200)
422
+ count = np.sum(edges > 0)
423
+ fractal_dim = np.log(count + 1) / np.log(size)
424
+ return np.array([fractal_dim])
425
+
426
+
427
+ def glossiness_index(image):
428
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
429
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
430
+ glossiness = np.mean(gray[binary == 255])
431
+ return np.array([glossiness])
432
+
433
+ def histogram_oriented_gradients(image):
434
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
435
+ features, _ = hog(gray, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(2, 2), visualize=True)
436
+ return features
437
+
438
+ def color_entropy(image):
439
+ entropy = shannon_entropy(image)
440
+ return np.array([entropy])
441
+
442
+ def spatial_color_distribution(image, grid_size=4):
443
+ h, w, _ = image.shape
444
+ features = []
445
+ for i in range(grid_size):
446
+ for j in range(grid_size):
447
+ x_start = i * h // grid_size
448
+ x_end = (i + 1) * h // grid_size
449
+ y_start = j * w // grid_size
450
+ y_end = (j + 1) * w // grid_size
451
+ patch = image[x_start:x_end, y_start:y_end]
452
+ hist = cv2.calcHist([patch], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]).flatten()
453
+ hist = hist / np.sum(hist)
454
+ features.extend(hist)
455
+ return np.array(features)
456
+
457
+ def uniform_region_features(image):
458
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
459
+ num_labels, labels = cv2.connectedComponents(gray)
460
+ unique, counts = np.unique(labels, return_counts=True)
461
+ uniformity = np.sum((counts / np.sum(counts)) ** 2)
462
+ return np.array([uniformity])
463
+
464
+ def color_space_features(image):
465
+ ycbcr = cv2.cvtColor(image, cv2.COLOR_RGB2YCrCb)
466
+ ycbcr_hist = cv2.calcHist([ycbcr], [1, 2], None, [16, 16], [0, 256, 0, 256]).flatten()
467
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
468
+ lab_hist = cv2.calcHist([lab], [1, 2], None, [16, 16], [0, 256, 0, 256]).flatten()
469
+ ycbcr_hist = ycbcr_hist / np.sum(ycbcr_hist)
470
+ lab_hist = lab_hist / np.sum(lab_hist)
471
+ return np.concatenate([ycbcr_hist, lab_hist])
472
  @st.cache_resource
473
+ def extract_features(image):
474
+ color_hist = color_histogram(image)
475
+ color_mom = color_moments(image)
476
+ dom_color = dominant_color_descriptor(image)
477
+ ccv = color_coherence_vector(image)
478
+ edges = edge_features(image)
479
+ hsv_hist = histogram_in_color_space(image, color_space='HSV')
480
+ glcm = glcm_features(image)
481
+ gabor = gabor_features(image)
482
+ wavelet = wavelet_features(image)
483
+ illumination = illumination_features(image)
484
+ saturation = saturation_index(image)
485
+ lbp = local_binary_pattern_features(image)
486
+ fourier = fourier_transform_features(image)
487
+ fractal = fractal_dimension(image)
488
+ return np.concatenate([
489
+ color_hist,
490
+ color_mom,
491
+ dom_color,
492
+ ccv,
493
+ edges,
494
+ hsv_hist,
495
+ glcm,
496
+ gabor,
497
+ wavelet,
498
+ illumination,
499
+ saturation,
500
+ lbp,
501
+ fourier,
502
+ fractal,
503
+ ])
504
+
505
  def load_models():
506
  """Load both object detection and classification models"""
507
  # Load object detection model
 
511
  object_detection_model.eval()
512
 
513
  # Load classification model
514
+ classification_model = tf.keras.models.load_model('mlp_model.h5')
515
 
516
  return object_detection_model, classification_model, device
517
 
518
+ def create_efficientnetb0_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
519
+ # Xây dựng mô hình EfficientNetB0 đã huấn luyện sẵn từ TensorFlow
520
+ inputs = layers.Input(shape=input_shape)
521
+
522
+ # Thêm lớp Lambda để tiền xử lý ảnh
523
+ x = Lambda(preprocess_input, output_shape=input_shape)(inputs) # Xử lý ảnh đầu vào
524
+
525
+ # Sử dụng mô hình EfficientNetB0 đã được huấn luyện sẵn
526
+ efficientnetb0_model = EfficientNetB0(include_top=False, weights='imagenet', input_tensor=x)
527
+
528
+ # Trích xuất đặc trưng từ mô hình EfficientNetB0
529
+ x = layers.GlobalAveragePooling2D()(efficientnetb0_model.output)
530
+
531
+ if num_classes:
532
+ x = layers.Dense(num_classes, activation='softmax')(x) # Thêm lớp phân loại (nếu có)
533
+
534
+ return models.Model(inputs=inputs, outputs=x)
535
+
536
  def perform_object_detection(image, model, device):
537
  original_size = image.size
538
  target_size = (256, 256)
 
587
 
588
  def preprocess_image(image):
589
  """Preprocess the image for classification"""
590
+ # Convert image to numpy array and resize
591
  img_array = np.array(image)
592
  img_array = cv2.resize(img_array, (256, 256))
593
+
594
+ # Extract custom features (ensure this returns a 1D array)
595
+ features = extract_features(img_array)
596
+ features = features.flatten() # Ensure 1D
597
+
598
+ # Extract EfficientNet features
599
+ model_extractor = create_efficientnetb0_feature_extractor()
600
+ model_features = model_extractor.predict(np.expand_dims(img_array, axis=0))
601
+ model_features = model_features.flatten() # Convert to 1D array
602
+
603
+ # Combine features
604
+ features_combined = np.concatenate([features, model_features])
605
+ features_combined = features_combined.reshape(1, -1) # Reshape to 2D for scaler
606
+
607
+ # Load and apply scaler
608
+ scaler = joblib.load('scaler.pkl')
609
+ processed_image = scaler.transform(features_combined)
610
+
611
+ return processed_image
612
 
613
  def get_top_predictions(prediction, class_names, top_k=5):
614
  """Get top k predictions with their probabilities"""
 
659
  all_predictions = []
660
  all_image=[]
661
  for idx, cropped_image in cropped_images:
662
+ processed_image = preprocess_image(cropped_image)
663
+ prediction = classification_model.predict(processed_image)
664
+ top_predictions = get_top_predictions(prediction, class_names)
665
+ all_predictions.append([idx,top_predictions])
666
+ all_image.append(cropped_image)
 
 
667
  # Store in session state
668
  st.session_state.predictions = all_predictions
669
  st.session_state.image = all_image