xcurvnubaim commited on
Commit
5034720
·
1 Parent(s): d9c0ea9

feat: add prediction with full image

Browse files
Files changed (1) hide show
  1. main.py +14 -1
main.py CHANGED
@@ -74,10 +74,21 @@ def animal_detect_and_classify(img_path):
74
  # Make predictions using the classification model
75
  prediction = classification_model.predict(inp_array)
76
  # Map predictions to labels
77
- threshold = 0.75
78
  predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
79
  print(predicted_labels)
80
  combined_results.append(((x1, y1, x2, y2), predicted_labels))
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  return combined_results
83
 
@@ -101,6 +112,8 @@ def plot_detected_rectangles(image, detections, output_path):
101
 
102
  # Iterate over each detected rectangle and its corresponding class name
103
  for rectangle, class_names in detections:
 
 
104
  # Extract the coordinates of the rectangle
105
  x1, y1, x2, y2 = rectangle
106
 
 
74
  # Make predictions using the classification model
75
  prediction = classification_model.predict(inp_array)
76
  # Map predictions to labels
77
+ threshold = 0.66
78
  predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
79
  print(predicted_labels)
80
  combined_results.append(((x1, y1, x2, y2), predicted_labels))
81
+ y2, x2, _ = img.shape
82
+ detect_img = img[0:y2, 0:x2]
83
+ detect_img = cv2.cvtColor(detect_img, cv2.COLOR_BGR2RGB)
84
+ detect_img = cv2.resize(detect_img, (224, 224))
85
+ inp_array = np.array(detect_img)
86
+ inp_array = inp_array.reshape((-1, 224, 224, 3))
87
+ inp_array = tf.keras.applications.efficientnet.preprocess_input(inp_array)
88
+ prediction = classification_model.predict(inp_array)
89
+ threshold = 0.66
90
+ predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "unknown" for pred in prediction]
91
+ combined_results.append(((0, 0, x2, y2), predicted_labels))
92
 
93
  return combined_results
94
 
 
112
 
113
  # Iterate over each detected rectangle and its corresponding class name
114
  for rectangle, class_names in detections:
115
+ if class_names[0] == "unknown":
116
+ continue
117
  # Extract the coordinates of the rectangle
118
  x1, y1, x2, y2 = rectangle
119