eisei-ai-space / main.py
vumichien's picture
Update main.py
0830381 verified
import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, Response
import uvicorn
import logging
import time
import supervision as sv
from ultralytics import YOLO
# ASGI application; the route handlers in this module are registered on it
# through the @app.* decorators.
app = FastAPI()

# Detector weights, loaded once at import time and shared by every request.
MODEL_WEIGHTS_PATH = "models/best_v21.pt"
model = YOLO(MODEL_WEIGHTS_PATH, task="detect")
def _format_data_value(value, index):
    """Return ``str`` of the entry of ``value`` that belongs to detection ``index``.

    Extra per-detection data may be a numpy array or a plain Python list.
    Arrays with at least one dimension, and lists/tuples, are indexed per
    detection. Anything else (scalars, 0-d arrays) is shared by every
    detection and stringified whole. The previous ``value.ndim`` check raised
    ``AttributeError`` on plain lists.
    """
    ndim = getattr(value, "ndim", None)
    if ndim is None:
        per_detection = isinstance(value, (list, tuple))
    else:
        per_detection = ndim != 0 and hasattr(value, "__getitem__")
    return str(value[index]) if per_detection else str(value)


def parse_detection(detections):
    """Convert a ``supervision`` ``Detections`` object into JSON-serialisable rows.

    Each row contains:

    * the box position and size as truncated ints;
    * the class id, confidence and tracker id, each ``""`` when the
      corresponding array is ``None``;
    * one string entry per key in ``detections.data``, with ``"class_name"``
      renamed to ``"class"``.

    Args:
        detections: object exposing ``xyxy`` (rows of
            ``x_min, y_min, x_max, y_max``), ``class_id``, ``confidence``,
            ``tracker_id`` and, optionally, a ``data`` mapping.

    Returns:
        A list with one dict per detection, in ``xyxy`` order.
    """
    extra_fields = getattr(detections, "data", None) or {}
    parsed_rows = []
    for i, box in enumerate(detections.xyxy):
        x_min, y_min, x_max, y_max = (float(coord) for coord in box[:4])
        row = {
            # NOTE(review): "x" is filled from y_min and "y" from x_min, which
            # looks transposed. Kept as-is because API clients may depend on
            # it; confirm with the consumers before swapping.
            "x": int(y_min),
            "y": int(x_min),
            "width": int(x_max - x_min),
            "height": int(y_max - y_min),
            "class_id": (
                "" if detections.class_id is None else int(detections.class_id[i])
            ),
            "confidence": (
                ""
                if detections.confidence is None
                else float(detections.confidence[i])
            ),
            "tracker_id": (
                ""
                if detections.tracker_id is None
                else int(detections.tracker_id[i])
            ),
        }
        for key, value in extra_fields.items():
            # The label is exposed under "class" rather than "class_name".
            out_key = "class" if key == "class_name" else key
            row[out_key] = _format_data_value(value, i)
        parsed_rows.append(row)
    return parsed_rows
def infer(image):
    """Decode an uploaded image, run the YOLO detector and package the result.

    Args:
        image: Encoded image bytes (any buffer ``np.frombuffer`` accepts) in a
            format ``cv2.imdecode`` can read.

    Returns:
        dict: ``{"predictions": <rows from parse_detection>,
        "image": {"width": int, "height": int}}``. The width/height are read
        from ``results.orig_shape`` of the frame handed to the model.

    Raises:
        ValueError: if ``image`` is empty or cannot be decoded as an image.
    """
    encoded = np.frombuffer(image, np.uint8)
    if encoded.size == 0:
        raise ValueError("empty image payload")
    frame = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # imdecode reports unreadable data by returning None instead of raising;
    # without this check cv2.resize below fails with an opaque assertion.
    if frame is None:
        raise ValueError("could not decode uploaded data as an image")
    # NOTE(review): the frame is stretched to 640x640 without preserving the
    # aspect ratio, so box coordinates and the reported image size describe
    # the resized frame rather than the upload. Confirm clients expect that.
    frame = cv2.resize(frame, (640, 640))
    results = model(frame, conf=0.6, iou=0.25, imgsz=640)[0]
    height, width = results.orig_shape[0], results.orig_shape[1]
    print(results.speed)
    detections = sv.Detections.from_ultralytics(results)
    return {
        "predictions": parse_detection(detections),
        "image": {"width": width, "height": height},
    }
@app.post("/process-image/")
async def process_image(image: UploadFile = File(...)):
    """Accept an uploaded image file and respond with its detections as JSON.

    The whole upload is read into memory and handed to ``infer``. The dict
    that ``infer`` returns is sent back unchanged as the JSON body.
    """
    logging.info("Received process-image request for file: %s", image.filename)
    payload = await image.read()
    # NOTE(review): infer() is synchronous and runs directly on the event
    # loop, blocking other requests while it executes. This also keeps calls
    # to the shared model serialized, so check the model's thread-safety
    # before moving the call to a thread pool.
    detection_report = infer(payload)
    logging.info("Returning JSON results")
    return JSONResponse(content=detection_report)
@app.get("/")
def hello_world():
    """Return a fixed greeting from the root URL."""
    greeting = 'Hello World from Detomo AI!'
    return greeting
if __name__ == "__main__":
    # Local development server. The app is given as the import string
    # "main:app" rather than the object because uvicorn's reload mode has to
    # re-import the module itself.
    dev_server_options = {"port": 8001, "reload": True}
    uvicorn.run("main:app", **dev_server_options)