Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,42 +8,175 @@ from tensorflow.keras.applications import EfficientNetB0
|
|
8 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
9 |
import joblib
|
10 |
import io
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def load_model_and_scaler():
|
48 |
"""Load the trained model and scaler"""
|
49 |
try:
|
@@ -60,34 +193,34 @@ def color_histogram(image, bins=16):
|
|
60 |
hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
|
61 |
hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
|
62 |
hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
|
63 |
-
|
64 |
hist_r = hist_r / (np.sum(hist_r) + 1e-7)
|
65 |
hist_g = hist_g / (np.sum(hist_g) + 1e-7)
|
66 |
hist_b = hist_b / (np.sum(hist_b) + 1e-7)
|
67 |
-
|
68 |
return np.concatenate([hist_r, hist_g, hist_b])
|
69 |
|
70 |
def color_moments(image):
|
71 |
"""Calculate color moments features"""
|
72 |
img = image.astype(np.float32) / 255.0
|
73 |
moments = []
|
74 |
-
|
75 |
for i in range(3):
|
76 |
channel = img[:,:,i]
|
77 |
mean = np.mean(channel)
|
78 |
std = np.std(channel) + 1e-7
|
79 |
skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
|
80 |
moments.extend([mean, std, skewness])
|
81 |
-
|
82 |
return np.array(moments)
|
83 |
|
84 |
def dominant_color_descriptor(image, k=3):
|
85 |
"""Calculate dominant color descriptor"""
|
86 |
pixels = image.reshape(-1, 3).astype(np.float32)
|
87 |
-
|
88 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
|
89 |
flags = cv2.KMEANS_RANDOM_CENTERS
|
90 |
-
|
91 |
try:
|
92 |
_, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
|
93 |
unique, counts = np.unique(labels, return_counts=True)
|
@@ -100,16 +233,16 @@ def color_coherence_vector(image, k=3):
|
|
100 |
"""Calculate color coherence vector"""
|
101 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
102 |
gray = np.uint8(gray)
|
103 |
-
|
104 |
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
105 |
num_labels, labels = cv2.connectedComponents(binary)
|
106 |
-
|
107 |
ccv = []
|
108 |
for i in range(1, min(k+1, num_labels)):
|
109 |
region_mask = (labels == i)
|
110 |
total_pixels = np.sum(region_mask)
|
111 |
ccv.extend([total_pixels, total_pixels])
|
112 |
-
|
113 |
ccv.extend([0] * (2 * k - len(ccv)))
|
114 |
return np.array(ccv[:2*k])
|
115 |
|
@@ -119,13 +252,13 @@ def create_vit_feature_extractor():
|
|
119 |
input_shape = (256, 256, 3)
|
120 |
inputs = layers.Input(shape=input_shape)
|
121 |
x = layers.Lambda(preprocess_input)(inputs)
|
122 |
-
|
123 |
base_model = EfficientNetB0(
|
124 |
include_top=False,
|
125 |
weights='imagenet',
|
126 |
input_tensor=x
|
127 |
)
|
128 |
-
|
129 |
x = layers.GlobalAveragePooling2D()(base_model.output)
|
130 |
return models.Model(inputs=inputs, outputs=x)
|
131 |
|
@@ -136,21 +269,21 @@ def extract_features(image):
|
|
136 |
moment_features = color_moments(image)
|
137 |
dominant_features = dominant_color_descriptor(image)
|
138 |
ccv_features = color_coherence_vector(image)
|
139 |
-
|
140 |
traditional_features = np.concatenate([
|
141 |
hist_features,
|
142 |
moment_features,
|
143 |
dominant_features,
|
144 |
ccv_features
|
145 |
])
|
146 |
-
|
147 |
# Deep features using ViT
|
148 |
feature_extractor = create_vit_feature_extractor()
|
149 |
vit_features = feature_extractor.predict(
|
150 |
np.expand_dims(image, axis=0),
|
151 |
verbose=0
|
152 |
)
|
153 |
-
|
154 |
# Combine all features
|
155 |
return np.concatenate([traditional_features, vit_features.flatten()])
|
156 |
|
@@ -159,100 +292,25 @@ def preprocess_image(image, scaler):
|
|
159 |
# Convert to RGB if needed
|
160 |
if image.mode != 'RGB':
|
161 |
image = image.convert('RGB')
|
162 |
-
|
163 |
# Convert to numpy array and resize
|
164 |
img_array = np.array(image)
|
165 |
img_array = cv2.resize(img_array, (256, 256))
|
166 |
img_array = img_array.astype('float32') / 255.0
|
167 |
-
|
168 |
# Extract all features
|
169 |
features = extract_features(img_array)
|
170 |
-
|
171 |
# Scale features using the provided scaler
|
172 |
scaled_features = scaler.transform(features.reshape(1, -1))
|
173 |
-
|
174 |
-
return scaled_features
|
175 |
|
176 |
-
|
177 |
-
"""Get top k predictions with their probabilities"""
|
178 |
-
top_indices = prediction.argsort()[0][-top_k:][::-1]
|
179 |
-
return [
|
180 |
-
(class_names[i], float(prediction[0][i]) * 100)
|
181 |
-
for i in top_indices
|
182 |
-
]
|
183 |
|
184 |
-
def
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
model, scaler = load_model_and_scaler()
|
190 |
-
if model is None or scaler is None:
|
191 |
-
st.error("Failed to load model or scaler. Please ensure both files exist.")
|
192 |
-
return
|
193 |
-
|
194 |
-
# Initialize session state
|
195 |
-
if 'predictions' not in st.session_state:
|
196 |
-
st.session_state.predictions = None
|
197 |
-
|
198 |
-
col1, col2 = st.columns(2)
|
199 |
-
|
200 |
-
with col1:
|
201 |
-
st.subheader("Upload Image")
|
202 |
-
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
203 |
-
|
204 |
-
if uploaded_file is not None:
|
205 |
-
try:
|
206 |
-
image = Image.open(uploaded_file)
|
207 |
-
st.image(image, caption="Uploaded Image", use_column_width=True)
|
208 |
-
|
209 |
-
with st.spinner('Analyzing image...'):
|
210 |
-
processed_image = preprocess_image(image, scaler)
|
211 |
-
prediction = model.predict(processed_image, verbose=0)
|
212 |
-
|
213 |
-
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
|
214 |
-
st.session_state.predictions = get_top_predictions(prediction, class_names)
|
215 |
-
|
216 |
-
except Exception as e:
|
217 |
-
st.error(f"Error processing image: {str(e)}")
|
218 |
-
|
219 |
-
with col2:
|
220 |
-
st.subheader("Prediction Results")
|
221 |
-
if st.session_state.predictions:
|
222 |
-
# Display main prediction
|
223 |
-
top_class, top_confidence = st.session_state.predictions[0]
|
224 |
-
st.markdown(
|
225 |
-
f"""
|
226 |
-
<div class='prediction-card'>
|
227 |
-
<h3>Primary Prediction: Grade {top_class}</h3>
|
228 |
-
<h3>Confidence: {top_confidence:.2f}%</h3>
|
229 |
-
</div>
|
230 |
-
""",
|
231 |
-
unsafe_allow_html=True
|
232 |
-
)
|
233 |
-
|
234 |
-
# Display confidence bar
|
235 |
-
st.progress(top_confidence / 100)
|
236 |
-
|
237 |
-
# Display top 5 predictions
|
238 |
-
st.markdown("### Top 5 Predictions")
|
239 |
-
st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
|
240 |
-
|
241 |
-
for class_name, confidence in st.session_state.predictions:
|
242 |
-
cols = st.columns([2, 6, 2])
|
243 |
-
with cols[0]:
|
244 |
-
st.write(f"Grade {class_name}")
|
245 |
-
with cols[1]:
|
246 |
-
st.progress(confidence / 100)
|
247 |
-
with cols[2]:
|
248 |
-
st.write(f"{confidence:.2f}%")
|
249 |
-
|
250 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
251 |
-
else:
|
252 |
-
st.info("Upload an image to see the predictions")
|
253 |
-
|
254 |
-
st.markdown("---")
|
255 |
-
st.markdown("Made with ❤️ using Streamlit")
|
256 |
|
257 |
if __name__ == "__main__":
|
258 |
-
main()
|
|
|
8 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
9 |
import joblib
|
10 |
import io
|
11 |
+
import os
|
12 |
|
13 |
+
# Add Cloudinary import
|
14 |
+
import cloudinary
|
15 |
+
import cloudinary.uploader
|
16 |
+
from cloudinary.utils import cloudinary_url
|
17 |
+
|
18 |
+
# Cloudinary Configuration
|
19 |
+
cloudinary.config(
|
20 |
+
cloud_name = os.getenv("CLOUD"),
|
21 |
+
api_key = os.getenv("API"),
|
22 |
+
api_secret = os.getenv("SECRET"),
|
23 |
+
secure=True
|
24 |
)
|
25 |
|
26 |
+
def upload_to_cloudinary(file_path, label):
|
27 |
+
"""
|
28 |
+
Upload file to Cloudinary with specified label as folder
|
29 |
+
"""
|
30 |
+
try:
|
31 |
+
# Upload to Cloudinary
|
32 |
+
upload_result = cloudinary.uploader.upload(
|
33 |
+
file_path,
|
34 |
+
folder=label,
|
35 |
+
public_id=f"{label}_{os.path.basename(file_path)}"
|
36 |
+
)
|
37 |
+
|
38 |
+
# Generate optimized URLs
|
39 |
+
optimize_url, _ = cloudinary_url(
|
40 |
+
upload_result['public_id'],
|
41 |
+
fetch_format="auto",
|
42 |
+
quality="auto"
|
43 |
+
)
|
44 |
+
|
45 |
+
auto_crop_url, _ = cloudinary_url(
|
46 |
+
upload_result['public_id'],
|
47 |
+
width=500,
|
48 |
+
height=500,
|
49 |
+
crop="auto",
|
50 |
+
gravity="auto"
|
51 |
+
)
|
52 |
+
|
53 |
+
return {
|
54 |
+
"upload_result": upload_result,
|
55 |
+
"optimize_url": optimize_url,
|
56 |
+
"auto_crop_url": auto_crop_url
|
57 |
+
}
|
58 |
+
|
59 |
+
except Exception as e:
|
60 |
+
return f"Error uploading to Cloudinary: {str(e)}"
|
61 |
+
|
62 |
+
def main():
|
63 |
+
st.title("🨨 Phân loại đá")
|
64 |
+
st.write("Tải lên hình ảnh của một viên đá để phân loại loại của nó.")
|
65 |
+
|
66 |
+
# Load model and scaler
|
67 |
+
model, scaler = load_model_and_scaler()
|
68 |
+
if model is None or scaler is None:
|
69 |
+
st.error("Không thể tải mô hình hoặc bộ chuẩn hóa. Vui lòng đảm bảo rằng cả hai tệp đều tồn tại.")
|
70 |
+
return
|
71 |
+
|
72 |
+
# Initialize session state
|
73 |
+
if 'predictions' not in st.session_state:
|
74 |
+
st.session_state.predictions = None
|
75 |
+
if 'uploaded_image' not in st.session_state:
|
76 |
+
st.session_state.uploaded_image = None
|
77 |
+
|
78 |
+
col1, col2 = st.columns(2)
|
79 |
+
|
80 |
+
with col1:
|
81 |
+
st.subheader("Tải lên Hình ảnh")
|
82 |
+
uploaded_file = st.file_uploader("Chọn hình ảnh...", type=["jpg", "jpeg", "png"])
|
83 |
+
|
84 |
+
if uploaded_file is not None:
|
85 |
+
try:
|
86 |
+
image = Image.open(uploaded_file)
|
87 |
+
st.image(image, caption="Hình ảnh đã tải lên", use_column_width=True)
|
88 |
+
st.session_state.uploaded_image = image
|
89 |
+
|
90 |
+
with st.spinner('Đang phân tích hình ảnh...'):
|
91 |
+
processed_image = preprocess_image(image, scaler)
|
92 |
+
prediction = model.predict(processed_image, verbose=0)
|
93 |
+
|
94 |
+
class_names = ['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7']
|
95 |
+
st.session_state.predictions = get_top_predictions(prediction, class_names)
|
96 |
+
|
97 |
+
except Exception as e:
|
98 |
+
st.error(f"Lỗi khi xử lý hình ảnh: {str(e)}")
|
99 |
+
|
100 |
+
with col2:
|
101 |
+
st.subheader("Kết quả Dự đoán")
|
102 |
+
if st.session_state.predictions:
|
103 |
+
# Display main prediction
|
104 |
+
top_class, top_confidence = st.session_state.predictions[0]
|
105 |
+
st.markdown(
|
106 |
+
f"""
|
107 |
+
<div class='prediction-card'>
|
108 |
+
<h3>Dự đoán chính: Màu {top_class}</h3>
|
109 |
+
<h3>Độ tin cậy: {top_confidence:.2f}%</h3>
|
110 |
+
</div>
|
111 |
+
""",
|
112 |
+
unsafe_allow_html=True
|
113 |
+
)
|
114 |
+
|
115 |
+
# Display confidence bar
|
116 |
+
st.progress(top_confidence / 100)
|
117 |
+
|
118 |
+
# Display top 5 predictions
|
119 |
+
st.markdown("### 5 Dự đoán hàng đầu")
|
120 |
+
st.markdown("<div class='top-predictions'>", unsafe_allow_html=True)
|
121 |
+
|
122 |
+
for class_name, confidence in st.session_state.predictions:
|
123 |
+
st.markdown(
|
124 |
+
f"**Màu {class_name}: Độ tin cậy {confidence:.2f}%**"
|
125 |
+
)
|
126 |
+
st.progress(confidence / 100)
|
127 |
+
|
128 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
129 |
+
|
130 |
+
# User Confirmation Section
|
131 |
+
st.markdown("### Xác nhận độ chính xác của mô hình")
|
132 |
+
st.write("Giúp chúng tôi cải thiện mô hình bằng cách xác nhận độ chính xác của dự đoán.")
|
133 |
+
|
134 |
+
# Accuracy Radio Button
|
135 |
+
accuracy_option = st.radio(
|
136 |
+
"Dự đoán có chính xác không?",
|
137 |
+
["Chọn", "Chính xác", "Không chính xác"],
|
138 |
+
index=0
|
139 |
+
)
|
140 |
+
|
141 |
+
if accuracy_option == "Không chính xác":
|
142 |
+
# Input for correct grade
|
143 |
+
correct_grade = st.selectbox(
|
144 |
+
"Chọn màu đá đúng:",
|
145 |
+
['10', '6.5', '7', '7.5', '8', '8.5', '9', '9.2', '9.5', '9.7'],
|
146 |
+
index=None,
|
147 |
+
placeholder="Chọn màu đúng"
|
148 |
+
)
|
149 |
+
|
150 |
+
# Upload button
|
151 |
+
if st.button("Tải lên Hình ảnh để sửa chữa"):
|
152 |
+
if correct_grade and st.session_state.uploaded_image:
|
153 |
+
# Save the image temporarily
|
154 |
+
temp_image_path = f"temp_image_{hash(uploaded_file.name)}.png"
|
155 |
+
st.session_state.uploaded_image.save(temp_image_path)
|
156 |
+
|
157 |
+
try:
|
158 |
+
# Upload to Cloudinary
|
159 |
+
cloudinary_result = upload_to_cloudinary(temp_image_path, correct_grade)
|
160 |
+
|
161 |
+
if isinstance(cloudinary_result, dict):
|
162 |
+
st.success(f"Hình ảnh đã được tải lên thành công cho màu {correct_grade}")
|
163 |
+
st.write(f"URL công khai: {cloudinary_result['upload_result']['secure_url']}")
|
164 |
+
else:
|
165 |
+
st.error(cloudinary_result)
|
166 |
+
|
167 |
+
# Clean up temporary file
|
168 |
+
os.remove(temp_image_path)
|
169 |
+
|
170 |
+
except Exception as e:
|
171 |
+
st.error(f"Tải lên thất bại: {str(e)}")
|
172 |
+
else:
|
173 |
+
st.warning("Vui lòng chọn màu đúng trước khi tải lên.")
|
174 |
+
else:
|
175 |
+
st.info("Tải lên hình ảnh để xem các dự đoán.")
|
176 |
+
|
177 |
+
st.markdown("---")
|
178 |
+
st.markdown("Tạo bởi ❤️ với Streamlit")
|
179 |
+
|
180 |
def load_model_and_scaler():
|
181 |
"""Load the trained model and scaler"""
|
182 |
try:
|
|
|
193 |
hist_r = cv2.calcHist([image], [0], None, [bins], [0, 256]).flatten()
|
194 |
hist_g = cv2.calcHist([image], [1], None, [bins], [0, 256]).flatten()
|
195 |
hist_b = cv2.calcHist([image], [2], None, [bins], [0, 256]).flatten()
|
196 |
+
|
197 |
hist_r = hist_r / (np.sum(hist_r) + 1e-7)
|
198 |
hist_g = hist_g / (np.sum(hist_g) + 1e-7)
|
199 |
hist_b = hist_b / (np.sum(hist_b) + 1e-7)
|
200 |
+
|
201 |
return np.concatenate([hist_r, hist_g, hist_b])
|
202 |
|
203 |
def color_moments(image):
|
204 |
"""Calculate color moments features"""
|
205 |
img = image.astype(np.float32) / 255.0
|
206 |
moments = []
|
207 |
+
|
208 |
for i in range(3):
|
209 |
channel = img[:,:,i]
|
210 |
mean = np.mean(channel)
|
211 |
std = np.std(channel) + 1e-7
|
212 |
skewness = np.mean(((channel - mean) / std) ** 3) if std != 0 else 0
|
213 |
moments.extend([mean, std, skewness])
|
214 |
+
|
215 |
return np.array(moments)
|
216 |
|
217 |
def dominant_color_descriptor(image, k=3):
|
218 |
"""Calculate dominant color descriptor"""
|
219 |
pixels = image.reshape(-1, 3).astype(np.float32)
|
220 |
+
|
221 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
|
222 |
flags = cv2.KMEANS_RANDOM_CENTERS
|
223 |
+
|
224 |
try:
|
225 |
_, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, flags)
|
226 |
unique, counts = np.unique(labels, return_counts=True)
|
|
|
233 |
"""Calculate color coherence vector"""
|
234 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
235 |
gray = np.uint8(gray)
|
236 |
+
|
237 |
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
238 |
num_labels, labels = cv2.connectedComponents(binary)
|
239 |
+
|
240 |
ccv = []
|
241 |
for i in range(1, min(k+1, num_labels)):
|
242 |
region_mask = (labels == i)
|
243 |
total_pixels = np.sum(region_mask)
|
244 |
ccv.extend([total_pixels, total_pixels])
|
245 |
+
|
246 |
ccv.extend([0] * (2 * k - len(ccv)))
|
247 |
return np.array(ccv[:2*k])
|
248 |
|
|
|
252 |
input_shape = (256, 256, 3)
|
253 |
inputs = layers.Input(shape=input_shape)
|
254 |
x = layers.Lambda(preprocess_input)(inputs)
|
255 |
+
|
256 |
base_model = EfficientNetB0(
|
257 |
include_top=False,
|
258 |
weights='imagenet',
|
259 |
input_tensor=x
|
260 |
)
|
261 |
+
|
262 |
x = layers.GlobalAveragePooling2D()(base_model.output)
|
263 |
return models.Model(inputs=inputs, outputs=x)
|
264 |
|
|
|
269 |
moment_features = color_moments(image)
|
270 |
dominant_features = dominant_color_descriptor(image)
|
271 |
ccv_features = color_coherence_vector(image)
|
272 |
+
|
273 |
traditional_features = np.concatenate([
|
274 |
hist_features,
|
275 |
moment_features,
|
276 |
dominant_features,
|
277 |
ccv_features
|
278 |
])
|
279 |
+
|
280 |
# Deep features using ViT
|
281 |
feature_extractor = create_vit_feature_extractor()
|
282 |
vit_features = feature_extractor.predict(
|
283 |
np.expand_dims(image, axis=0),
|
284 |
verbose=0
|
285 |
)
|
286 |
+
|
287 |
# Combine all features
|
288 |
return np.concatenate([traditional_features, vit_features.flatten()])
|
289 |
|
|
|
292 |
# Convert to RGB if needed
|
293 |
if image.mode != 'RGB':
|
294 |
image = image.convert('RGB')
|
295 |
+
|
296 |
# Convert to numpy array and resize
|
297 |
img_array = np.array(image)
|
298 |
img_array = cv2.resize(img_array, (256, 256))
|
299 |
img_array = img_array.astype('float32') / 255.0
|
300 |
+
|
301 |
# Extract all features
|
302 |
features = extract_features(img_array)
|
303 |
+
|
304 |
# Scale features using the provided scaler
|
305 |
scaled_features = scaler.transform(features.reshape(1, -1))
|
|
|
|
|
306 |
|
307 |
+
return scaled_features
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
|
309 |
+
def get_top_predictions(prediction, class_names):
|
310 |
+
# Extract the top 5 predictions with confidence values
|
311 |
+
probabilities = tf.nn.softmax(prediction[0]).numpy()
|
312 |
+
top_indices = np.argsort(probabilities)[-5:][::-1]
|
313 |
+
return [(class_names[i], probabilities[i] * 100) for i in top_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
if __name__ == "__main__":
|
316 |
+
main()
|