Browse files
@@ -4,7 +4,20 @@ import numpy as np
4 |
import cv2
5 |
from PIL import Image
6 |
import io
7 |
8 |
# Set page config
9 |
10 |
page_title="Stone Classification",
@@ -54,15 +67,127 @@ st.markdown("""
54 |
55 |
def load_model():
56 |
"""Load the trained model"""
57 |
return tf.keras.models.load_model('
58 |
59 |
def preprocess_image(image):
60 |
"""Preprocess the uploaded image"""
61 |
# # Convert to RGB if needed
62 |
# if image.mode != 'RGB':
63 |
# image = image.convert('RGB')
64 |
65 |
# Convert to numpy array
66 |
img_array = np.array(image)
67 |
68 |
# # Convert to RGB if needed
@@ -90,8 +215,15 @@ def preprocess_image(image):
90 |
91 |
# Normalize
92 |
img_array = img_array.astype('float32') / 255.0
93 |
94 |
95 |
96 |
def get_top_predictions(prediction, class_names, top_k=5):
97 |
"""Get top k predictions with their probabilities"""
4 |
import cv2
5 |
from PIL import Image
6 |
import io
7 |
import os
8 |
import cv2
9 |
import numpy as np
10 |
from tensorflow import keras
11 |
from tensorflow.keras import layers, models
12 |
from sklearn.preprocessing import StandardScaler
13 |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
14 |
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15 |
import matplotlib.pyplot as plt
16 |
import random
17 |
from tensorflow.keras import layers, models
18 |
from tensorflow.keras.applications import EfficientNetB0
19 |
from tensorflow.keras.applications.efficientnet import preprocess_input
20 |
from tensorflow.keras.layers import Lambda # Đảm bảo nhập Lambda từ tensorflow.keras.layers
21 |
# Set page config
22 |
23 |
page_title="Stone Classification",
67 |
68 |
def load_model():
    """Load the trained Keras model stored in ``mlp_model.h5``.

    Returns:
        The object returned by ``tf.keras.models.load_model`` for that file.
    """
    model_path = 'mlp_model.h5'
    return tf.keras.models.load_model(model_path)
71 |
72 |
def color_histogram(image, bins=16):
    """Build a normalised per-channel colour histogram.

    For each of channels 0, 1 and 2 of ``image`` a ``bins``-bin histogram is
    computed with ``cv2.calcHist`` using the range argument ``[0, 256]``, and
    each histogram is divided by its own sum. The three histograms are then
    concatenated in channel order.

    Args:
        image: Image array handed directly to ``cv2.calcHist``; it needs at
            least three channels.
        bins: Number of histogram bins per channel.

    Returns:
        1-D array holding the three normalised histograms back to back
        (``3 * bins`` values).
    """
    per_channel = []
    for channel_index in range(3):
        counts = cv2.calcHist(
            [image], [channel_index], None, [bins], [0, 256]
        ).flatten()
        # Turn raw counts into a distribution that sums to 1.
        per_channel.append(counts / np.sum(counts))

    return np.concatenate(per_channel)
83 |
84 |
def color_moments(image):
    """Compute mean, standard deviation and skewness of each colour channel.

    The image is converted to ``float32`` and divided by 255 before the
    moments are taken, so for 8-bit input the mean and standard deviation are
    expressed on a 0-1 scale.

    Args:
        image: Array indexed as ``image[row, col, channel]`` with at least
            three channels; only channels 0, 1 and 2 are used.

    Returns:
        ``float32`` array of 9 values laid out as
        ``[mean_0, std_0, skew_0, mean_1, std_1, skew_1, mean_2, std_2, skew_2]``.
    """
    img = image.astype(np.float32) / 255.0

    moments = []
    for i in range(3):  # One triple of moments per colour channel.
        channel = img[:, :, i]

        mean = np.mean(channel)
        std = np.std(channel)
        if std > 0:
            skewness = np.mean(((channel - mean) / std) ** 3)
        else:
            # A constant channel has no spread; dividing by std == 0 would
            # put NaN into the feature vector, so define its skewness as 0.
            skewness = np.float32(0.0)

        moments.extend([mean, std, skewness])

    return np.array(moments, dtype=np.float32)
99 |
100 |
def dominant_color_descriptor(image, k=3):
    """Describe an image by its ``k`` dominant colours and their pixel shares.

    All pixels are clustered with ``cv2.kmeans``. The result is the flattened
    cluster centres followed by the fraction of pixels assigned to each
    cluster.

    Args:
        image: Array whose elements can be regrouped into rows of three
            values (``image.reshape(-1, 3)``).
        k: Number of clusters.

    Returns:
        1-D array of ``4 * k`` values: ``3 * k`` centre components followed by
        ``k`` shares. If ``cv2.kmeans`` raises ``cv2.error`` a zero vector of
        the same length is returned, so the feature length stays constant.
    """
    n_features = 4 * k  # k centres x 3 components + k cluster shares
    pixels = image.reshape(-1, 3).astype(np.float32)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    # NOTE(review): the line defining ``flags`` is not present in this copy of
    # the file; random-centre initialisation is assumed here. Confirm against
    # the training code, since the initialisation affects the features.
    flags = cv2.KMEANS_RANDOM_CENTERS

    try:
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
    except cv2.error:
        # Best-effort fallback, sized like the normal result.
        return np.zeros(n_features)

    # bincount (unlike np.unique) also reports clusters that received no
    # pixels, which keeps the number of shares fixed at k.
    counts = np.bincount(labels.ravel(), minlength=k)
    percentages = counts / len(labels)

    return np.concatenate([centers.flatten(), percentages])
119 |
120 |
def color_coherence_vector(image, k=3):
    """Summarise the sizes of up to ``k`` connected regions of an image.

    The image is converted to grayscale, binarised with Otsu's threshold and
    split into connected components. For component labels ``1 .. k`` (label 0
    is skipped) a ``(coherent_pixels, total_pixels)`` pair is recorded.

    Args:
        image: Image accepted by ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``.
        k: Maximum number of components to describe.

    Returns:
        1-D array of ``2 * k`` values; pairs for components that do not exist
        are filled with zeros.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # The threshold call below is given 8-bit data.
    gray = np.uint8(gray)

    # Threshold value 0 is a placeholder; THRESH_OTSU selects the real one.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    num_labels, labels = cv2.connectedComponents(binary)

    ccv = []
    for i in range(1, min(k + 1, num_labels)):
        total_pixels = np.sum(labels == i)
        # No coherence threshold is applied: every pixel of the region is
        # counted as coherent, so both entries of the pair are equal.
        coherent_pixels = total_pixels

        ccv.extend([coherent_pixels, total_pixels])

    # Pad with zeros so the vector always has 2 * k entries.
    ccv.extend([0] * (2 * k - len(ccv)))

    return np.array(ccv)
145 |
146 |
147 |
# ViT and Feature Extraction Functions (from previous implementation)
148 |
# (Keeping the Patches, PatchEncoder, and create_vit_feature_extractor functions)
149 |
150 |
def extract_features(image):
    """Extract multiple features from an image.

    Runs ``color_histogram``, ``color_moments``, ``dominant_color_descriptor``
    and ``color_coherence_vector`` on ``image`` with their default settings
    and joins the results, in that order, into one vector.

    Args:
        image: Image array passed unchanged to each extractor.

    Returns:
        Concatenation of the four feature arrays.
    """
    extractors = (
        color_histogram,
        color_moments,
        dominant_color_descriptor,
        color_coherence_vector,
    )
    return np.concatenate([extract(image) for extract in extractors])
160 |
161 |
from transformers import ViTFeatureExtractor, ViTModel
162 |
import torch
163 |
from tensorflow.keras import layers, models
164 |
165 |
def create_vit_feature_extractor(input_shape=(256, 256, 3), num_classes=None):
    """Build a Keras feature-extraction model on a pretrained backbone.

    Despite the function name, the backbone created here is
    ``EfficientNetB0`` (``weights='imagenet'``, ``include_top=False``), not a
    Vision Transformer.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        num_classes: If truthy, a softmax ``Dense`` layer with this many units
            is appended; otherwise the model outputs the pooled features.

    Returns:
        A ``models.Model`` mapping the image input to the pooled features or,
        when ``num_classes`` is given, to class probabilities.
    """
    image_input = layers.Input(shape=input_shape)

    # Run the EfficientNet ``preprocess_input`` function inside the graph.
    preprocessed = Lambda(preprocess_input, output_shape=input_shape)(image_input)

    # Pretrained backbone; another pretrained model (e.g. a real ViT) could be
    # substituted here.
    backbone = EfficientNetB0(
        include_top=False, weights='imagenet', input_tensor=preprocessed
    )

    # Pool the backbone's output feature map into one vector per image.
    outputs = layers.GlobalAveragePooling2D()(backbone.output)

    # Optional classification head.
    if num_classes:
        outputs = layers.Dense(num_classes, activation='softmax')(outputs)

    return models.Model(inputs=image_input, outputs=outputs)
184 |
185 |
def preprocess_image(image):
186 |
"""Preprocess the uploaded image"""
187 |
# # Convert to RGB if needed
188 |
# if image.mode != 'RGB':
189 |
# image = image.convert('RGB')
190 |
# Convert to numpy array
191 |
img_array = np.array(image)
192 |
193 |
# # Convert to RGB if needed
215 |
216 |
# Normalize
217 |
img_array = img_array.astype('float32') / 255.0
218 |
image_features = extract_features(img_array)
219 |
vit_extractor = create_vit_feature_extractor()
220 |
221 |
# Extract ViT features from the images
222 |
image_vit = vit_extractor.predict(img_array) # Dự đoán cho tập train
223 |
image_combined = np.concatenate([image_features, image_vit], axis=1)
224 |
scaler = StandardScaler()
225 |
image_scaled = scaler.fit_transform(image_combined)
226 |
return image_scaled
227 |
228 |
def get_top_predictions(prediction, class_names, top_k=5):
229 |
"""Get top k predictions with their probabilities"""