Spaces:
Runtime error
Runtime error
xcurvnubaim
commited on
Commit
·
5034720
1
Parent(s):
d9c0ea9
feat: add prediction with full image
Browse files
main.py
CHANGED
@@ -74,10 +74,21 @@ def animal_detect_and_classify(img_path):
|
|
74 |
# Make predictions using the classification model
|
75 |
prediction = classification_model.predict(inp_array)
|
76 |
# Map predictions to labels
|
77 |
-
threshold = 0.
|
78 |
predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
|
79 |
print(predicted_labels)
|
80 |
combined_results.append(((x1, y1, x2, y2), predicted_labels))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
return combined_results
|
83 |
|
@@ -101,6 +112,8 @@ def plot_detected_rectangles(image, detections, output_path):
|
|
101 |
|
102 |
# Iterate over each detected rectangle and its corresponding class name
|
103 |
for rectangle, class_names in detections:
|
|
|
|
|
104 |
# Extract the coordinates of the rectangle
|
105 |
x1, y1, x2, y2 = rectangle
|
106 |
|
|
|
74 |
# Make predictions using the classification model
|
75 |
prediction = classification_model.predict(inp_array)
|
76 |
# Map predictions to labels
|
77 |
+
threshold = 0.66
|
78 |
predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
|
79 |
print(predicted_labels)
|
80 |
combined_results.append(((x1, y1, x2, y2), predicted_labels))
|
81 |
+
y2, x2, _ = img.shape
|
82 |
+
detect_img = img[0:y2, 0:x2]
|
83 |
+
detect_img = cv2.cvtColor(detect_img, cv2.COLOR_BGR2RGB)
|
84 |
+
detect_img = cv2.resize(detect_img, (224, 224))
|
85 |
+
inp_array = np.array(detect_img)
|
86 |
+
inp_array = inp_array.reshape((-1, 224, 224, 3))
|
87 |
+
inp_array = tf.keras.applications.efficientnet.preprocess_input(inp_array)
|
88 |
+
prediction = classification_model.predict(inp_array)
|
89 |
+
threshold = 0.66
|
90 |
+
predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "unknown" for pred in prediction]
|
91 |
+
combined_results.append(((0, 0, x2, y2), predicted_labels))
|
92 |
|
93 |
return combined_results
|
94 |
|
|
|
112 |
|
113 |
# Iterate over each detected rectangle and its corresponding class name
|
114 |
for rectangle, class_names in detections:
|
115 |
+
if class_names[0] == "unknown":
|
116 |
+
continue
|
117 |
# Extract the coordinates of the rectangle
|
118 |
x1, y1, x2, y2 = rectangle
|
119 |
|