Spaces:
Runtime error
Runtime error
Added StreamingResponse
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from fastapi import FastAPI, File, UploadFile
|
2 |
-
from fastapi.responses import JSONResponse
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
5 |
from PIL import Image, ImageDraw
|
@@ -12,7 +12,7 @@ app = FastAPI()
|
|
12 |
# Add CORS middleware to allow communication with external clients
|
13 |
app.add_middleware(
|
14 |
CORSMiddleware,
|
15 |
-
allow_origins=["*"], # Change this to
|
16 |
allow_methods=["*"],
|
17 |
allow_headers=["*"],
|
18 |
)
|
@@ -35,7 +35,8 @@ def detect_accident(image):
|
|
35 |
for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
|
36 |
x_min, y_min, x_max, y_max = box
|
37 |
draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
|
38 |
-
|
|
|
39 |
|
40 |
return image
|
41 |
|
@@ -45,7 +46,7 @@ async def process_frame(file: UploadFile = File(...)):
|
|
45 |
try:
|
46 |
# Read and preprocess image
|
47 |
image = Image.open(io.BytesIO(await file.read()))
|
48 |
-
image = image.
|
49 |
|
50 |
# Detect accidents
|
51 |
processed_image = detect_accident(image)
|
@@ -55,10 +56,8 @@ async def process_frame(file: UploadFile = File(...)):
|
|
55 |
processed_image.save(img_byte_arr, format="JPEG")
|
56 |
img_byte_arr.seek(0)
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
media_type="image/jpeg"
|
61 |
-
)
|
62 |
except Exception as e:
|
63 |
return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
|
64 |
|
@@ -75,6 +74,7 @@ if __name__ == "__main__":
|
|
75 |
|
76 |
|
77 |
|
|
|
78 |
|
79 |
|
80 |
# import gradio as gr
|
|
|
1 |
from fastapi import FastAPI, File, UploadFile
|
2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
5 |
from PIL import Image, ImageDraw
|
|
|
12 |
# Add CORS middleware to allow communication with external clients
|
13 |
app.add_middleware(
|
14 |
CORSMiddleware,
|
15 |
+
allow_origins=["*"], # Change this to specific domains in production
|
16 |
allow_methods=["*"],
|
17 |
allow_headers=["*"],
|
18 |
)
|
|
|
35 |
for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
|
36 |
x_min, y_min, x_max, y_max = box
|
37 |
draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
|
38 |
+
label_name = model.config.id2label[label.item()]
|
39 |
+
draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
|
40 |
|
41 |
return image
|
42 |
|
|
|
46 |
try:
|
47 |
# Read and preprocess image
|
48 |
image = Image.open(io.BytesIO(await file.read()))
|
49 |
+
image = image.convert("RGB") # Ensure compatibility with the model
|
50 |
|
51 |
# Detect accidents
|
52 |
processed_image = detect_accident(image)
|
|
|
56 |
processed_image.save(img_byte_arr, format="JPEG")
|
57 |
img_byte_arr.seek(0)
|
58 |
|
59 |
+
# Return the image as a streaming response
|
60 |
+
return StreamingResponse(img_byte_arr, media_type="image/jpeg")
|
|
|
|
|
61 |
except Exception as e:
|
62 |
return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
|
63 |
|
|
|
74 |
|
75 |
|
76 |
|
77 |
+
|
78 |
|
79 |
|
80 |
# import gradio as gr
|