SonFox2920 committed on
Commit
bf836f1
·
verified ·
1 Parent(s): 2e8a520

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -6
app.py CHANGED
@@ -4,7 +4,20 @@ import numpy as np
4
  import cv2
5
  from PIL import Image
6
  import io
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Set page config
9
  st.set_page_config(
10
  page_title="Stone Classification",
@@ -54,15 +67,127 @@ st.markdown("""
54
  @st.cache_resource
55
  def load_model():
56
  """Load the trained model"""
57
- return tf.keras.models.load_model('custom_model.h5')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def preprocess_image(image):
60
  """Preprocess the uploaded image"""
61
  # # Convert to RGB if needed
62
  # if image.mode != 'RGB':
63
  # image = image.convert('RGB')
64
-
65
- # Convert to numpy array
66
  img_array = np.array(image)
67
 
68
  # # Convert to RGB if needed
@@ -90,8 +215,15 @@ def preprocess_image(image):
90
 
91
  # Normalize
92
  img_array = img_array.astype('float32') / 255.0
93
-
94
- return img_array
 
 
 
 
 
 
 
95
 
96
  def get_top_predictions(prediction, class_names, top_k=5):
97
  """Get top k predictions with their probabilities"""
 
4
  import cv2
5
  from PIL import Image
6
  import io
7
+ import os
8
+ import cv2
9
+ import numpy as np
10
+ from tensorflow import keras
11
+ from tensorflow.keras import layers, models
12
+ from sklearn.preprocessing import StandardScaler
13
+ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
14
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
+ import matplotlib.pyplot as plt
16
+ import random
17
+ from tensorflow.keras import layers, models
18
+ from tensorflow.keras.applications import EfficientNetB0
19
+ from tensorflow.keras.applications.efficientnet import preprocess_input
20
+ from tensorflow.keras.layers import Lambda # Đảm bảo nhập Lambda từ tensorflow.keras.layers
21
  # Set page config
22
  st.set_page_config(
23
  page_title="Stone Classification",
 
67
  @st.cache_resource
68
  def load_model():
69
  """Load the trained model"""
70
+ return tf.keras.models.load_model('mlp_model.h5')
71
+
72
def color_histogram(image, bins=16):
    """Return the concatenated, normalized histograms of channels 0, 1 and 2.

    Args:
        image: 3-channel image array accepted by ``cv2.calcHist``
            (presumably RGB order -- confirm against the caller).
        bins: Number of histogram bins per channel over the range [0, 256).

    Returns:
        1-D array of length ``3 * bins``. Each channel's histogram is scaled
        to sum to 1; a channel with no pixel inside [0, 256) is left as all
        zeros instead of becoming NaN.
    """
    histograms = []
    for channel in range(3):
        hist = cv2.calcHist([image], [channel], None, [bins], [0, 256]).flatten()
        total = np.sum(hist)
        # Guard the normalization: a zero total would turn the whole
        # channel histogram into NaN and poison downstream features.
        if total > 0:
            hist = hist / total
        histograms.append(hist)

    return np.concatenate(histograms)
83
+
84
def color_moments(image):
    """Compute mean, standard deviation and skewness for channels 0, 1, 2.

    The image is cast to float32 and divided by 255 before the statistics
    are taken.

    Args:
        image: Array indexable as ``image[:, :, c]`` for ``c`` in 0..2.

    Returns:
        1-D array of 9 values ordered
        ``[mean0, std0, skew0, mean1, std1, skew1, mean2, std2, skew2]``.
        A channel with zero standard deviation (uniform color) reports a
        skewness of 0 rather than NaN.
    """
    img = image.astype(np.float32) / 255.0

    moments = []
    for i in range(3):  # one pass per color channel
        channel = img[:, :, i]

        mean = np.mean(channel)
        std = np.std(channel)
        if std > 0:
            skewness = np.mean(((channel - mean) / std) ** 3)
        else:
            # A flat channel has no asymmetry; dividing by std == 0 would
            # produce NaN. Keep float32 so the result dtype is unchanged.
            skewness = np.float32(0.0)

        moments.extend([mean, std, skewness])

    return np.array(moments)
99
+
100
def dominant_color_descriptor(image, k=3):
    """Describe an image by its ``k`` dominant colors and their coverage.

    Pixels are clustered with ``cv2.kmeans``; the descriptor is the ``k``
    cluster centers (3 values each) followed by the fraction of pixels
    assigned to each cluster.

    Args:
        image: Array that can be reshaped to ``(-1, 3)``.
        k: Number of clusters.

    Returns:
        1-D array of length ``4 * k`` (``3 * k`` center values followed by
        ``k`` fractions). If OpenCV rejects the clustering, an all-zero
        vector of the same length is returned so the feature size stays
        constant.
    """
    pixels = image.reshape(-1, 3).astype(np.float32)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    flags = cv2.KMEANS_RANDOM_CENTERS

    try:
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
    except cv2.error:
        # Best-effort fallback; the length must match the success path
        # (3 * k center values + k fractions), not 2 * k.
        return np.zeros(4 * k)

    # bincount with minlength keeps one entry per cluster even when a
    # cluster received no pixels, so the output length never shrinks.
    counts = np.bincount(labels.ravel(), minlength=k)
    percentages = counts / len(labels)

    return np.concatenate([centers.flatten(), percentages])
119
 
120
def color_coherence_vector(image, k=3):
    """Return pixel counts of up to ``k`` connected regions of the image.

    The image is converted to grayscale, cast to uint8, binarized with
    Otsu's threshold and split into connected components. For component
    labels 1..k (label 0 is skipped) the region's pixel count is emitted
    twice, because every pixel of a region is counted as coherent here.

    Args:
        image: Image accepted by ``cv2.cvtColor`` with ``COLOR_RGB2GRAY``.
        k: Maximum number of regions to report.

    Returns:
        1-D array of length ``2 * k``, zero-padded when fewer than ``k``
        regions exist.
    """
    # Threshold needs an 8-bit single-channel image.
    gray = np.uint8(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY))
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    num_labels, labels = cv2.connectedComponents(binary)

    ccv = []
    for label in range(1, min(k + 1, num_labels)):
        region_size = np.sum(labels == label)
        # "coherent" and "total" are the same count in this implementation.
        ccv += [region_size, region_size]

    # Pad so the descriptor always has 2 * k entries.
    ccv += [0] * (2 * k - len(ccv))

    return np.array(ccv)
145
+
146
+
147
+ # ViT and Feature Extraction Functions (from previous implementation)
148
+ # (Keeping the Patches, PatchEncoder, and create_vit_feature_extractor functions)
149
+
150
def extract_features(image):
    """Build one flat feature vector from the hand-crafted color descriptors.

    The descriptors are concatenated in this fixed order: color histogram,
    color moments, dominant-color descriptor, color coherence vector.

    Args:
        image: Image array passed unchanged to each descriptor function.

    Returns:
        1-D array holding all descriptor values back to back.
    """
    extractors = (
        color_histogram,
        color_moments,
        dominant_color_descriptor,
        color_coherence_vector,
    )
    return np.concatenate([extract(image) for extract in extractors])
160
+
161
+ from transformers import ViTFeatureExtractor, ViTModel
162
+ import torch
163
+ from tensorflow.keras import layers, models
164
+
165
def create_vit_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
    """Build a Keras model that turns images into pooled backbone features.

    Despite the name, the backbone used here is EfficientNetB0 with ImageNet
    weights (not a ViT). Inputs pass through a ``Lambda`` layer applying
    ``preprocess_input``, then the backbone without its top, then global
    average pooling.

    Args:
        input_shape: Shape of a single input image.
        num_classes: When truthy, a softmax ``Dense`` layer with this many
            units is appended; otherwise the pooled features are the output.

    Returns:
        A ``models.Model`` mapping the image input to pooled features or,
        if ``num_classes`` is given, to class probabilities.
    """
    image_input = layers.Input(shape=input_shape)

    # Preprocess the input image inside the graph.
    preprocessed = Lambda(preprocess_input, output_shape=input_shape)(image_input)

    # Pretrained backbone; a real ViT could be substituted here.
    backbone = EfficientNetB0(include_top=False, weights='imagenet', input_tensor=preprocessed)

    # Collapse the backbone's spatial feature map into one feature vector.
    outputs = layers.GlobalAveragePooling2D()(backbone.output)

    if num_classes:
        # Optional classification head.
        outputs = layers.Dense(num_classes, activation='softmax')(outputs)

    return models.Model(inputs=image_input, outputs=outputs)
184
+
185
  def preprocess_image(image):
186
  """Preprocess the uploaded image"""
187
  # # Convert to RGB if needed
188
  # if image.mode != 'RGB':
189
  # image = image.convert('RGB')
190
+ # Convert to numpy array
 
191
  img_array = np.array(image)
192
 
193
  # # Convert to RGB if needed
 
215
 
216
  # Normalize
217
  img_array = img_array.astype('float32') / 255.0
218
+ image_features = extract_features(img_array)
219
+ vit_extractor = create_vit_feature_extractor()
220
+
221
+ # Trích xuất đặc trưng ViT từ các hình ảnh
222
+ image_vit = vit_extractor.predict(img_array) # Dự đoán cho tập train
223
+ image_combined = np.concatenate([image_features, image_vit], axis=1)
224
+ scaler = StandardScaler()
225
+ image_scaled = scaler.fit_transform(image_combined)
226
+ return image_scaled
227
 
228
  def get_top_predictions(prediction, class_names, top_k=5):
229
  """Get top k predictions with their probabilities"""