# wound_detect / predict.py
# Captured from the Hugging Face file page: author Ani14,
# commit 7eb41d6 ("Update predict.py"), 947 bytes.
import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, Response
from ultralytics import YOLO
# Weights file for the wound-detection model, resolved relative to the
# process working directory.
yolo_model_path = 'best.pt'
# Load the model once at import time so every request reuses the same instance.
yolo = YOLO(yolo_model_path)

# ASGI application exposing the detection endpoint.
app = FastAPI()
def detect_wounds(image):
    """Run the module-level YOLO model on *image* and return its boxes.

    Returns a plain Python list with one ``[x1, y1, x2, y2]`` corner-coordinate
    list per detection in the first result (Ultralytics ``Boxes.xyxy`` format).
    """
    first_result = yolo(image)[0]
    return first_result.boxes.xyxy.tolist()
def draw_boxes(image, boxes):
    """Outline every box in *boxes* on *image* with a 2-pixel green rectangle.

    Each box is a sequence of four numbers ``(x1, y1, x2, y2)``, truncated to
    ints. ``cv2.rectangle`` draws in place, so *image* itself is modified; it
    is also returned for convenient chaining.
    """
    green = (0, 255, 0)
    for coords in boxes:
        left, top, right, bottom = (int(c) for c in coords)
        cv2.rectangle(image, (left, top), (right, bottom), green, 2)
    return image
@app.post("/detect")
async def detect(image: UploadFile = File(...)):
    """Detect wounds in an uploaded image and return it annotated as a JPEG.

    The upload is decoded with OpenCV, passed to ``detect_wounds``, and each
    detected box is drawn with ``draw_boxes``. The annotated image is
    JPEG-encoded in memory and returned with ``media_type='image/jpeg'``.

    Raises:
        HTTPException(400): the upload is empty or cannot be decoded as an image.
        HTTPException(500): the annotated image could not be JPEG-encoded.
    """
    image_bytes = await image.read()
    # cv2.imdecode asserts on an empty buffer, so reject empty uploads up front.
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    raw = np.frombuffer(image_bytes, np.uint8)
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    # imdecode returns None (rather than raising) for data it cannot decode.
    if frame is None:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a decodable image."
        )

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop while it runs; consider offloading it once the
    # shared model's thread-safety is confirmed.
    wound_boxes = detect_wounds(frame)
    annotated = draw_boxes(frame, wound_boxes)

    # Encode in memory instead of writing to a fixed file on disk: a shared
    # path is overwritten by concurrent requests, so one client could be
    # served another client's image.
    ok, encoded = cv2.imencode(".jpg", annotated)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image.")
    return Response(content=encoded.tobytes(), media_type="image/jpeg")