xcurvnubaim commited on
Commit
349b9d7
·
1 Parent(s): d9a9d0e

feat: add class name in predict v2

Browse files
Files changed (1) hide show
  1. main.py +3 -8
main.py CHANGED
@@ -109,13 +109,6 @@ def plot_detected_rectangles(image, detections, output_path):
109
  cv2.imwrite(output_path, img_with_rectangles)
110
 
111
 
112
- # # Call the animal_detect_and_classify function to get detections
113
- # detections = animal_detect_and_classify('/content/cat_tiger.jpg')
114
-
115
- # # Plot the detected rectangles with their corresponding class names
116
- # plot_detected_rectangles(cv2.imread('/content/cat_tiger.jpg'), detections)
117
-
118
-
119
  @app.post("/predict/v2")
120
  async def predict_v2(file: UploadFile = File(...)):
121
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
@@ -124,10 +117,12 @@ async def predict_v2(file: UploadFile = File(...)):
124
  image = Image.open(BytesIO(contents))
125
  image.save("input/" + filename)
126
  detections = animal_detect_and_classify("input/" + filename)
 
127
  plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
128
  return {
129
  "message": "Detection and classification completed successfully",
130
- "data": "output/" + filename
 
131
  }
132
 
133
  @app.get("/image/")
 
109
  cv2.imwrite(output_path, img_with_rectangles)
110
 
111
 
 
 
 
 
 
 
 
112
  @app.post("/predict/v2")
113
  async def predict_v2(file: UploadFile = File(...)):
114
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
 
117
  image = Image.open(BytesIO(contents))
118
  image.save("input/" + filename)
119
  detections = animal_detect_and_classify("input/" + filename)
120
+ class_names = [class_name for _, class_name in detections]
121
  plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
122
  return {
123
  "message": "Detection and classification completed successfully",
124
+ "data": "output/" + filename,
125
+ "class_names": class_names
126
  }
127
 
128
  @app.get("/image/")