boostedhug commited on
Commit
3f13787
·
verified ·
1 Parent(s): 6e8e295

Added StreamingResponse

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
@@ -12,7 +12,7 @@ app = FastAPI()
12
  # Add CORS middleware to allow communication with external clients
13
  app.add_middleware(
14
  CORSMiddleware,
15
- allow_origins=["*"], # Change this to the specific domain in production
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
@@ -35,7 +35,8 @@ def detect_accident(image):
35
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  x_min, y_min, x_max, y_max = box
37
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
- draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
 
39
 
40
  return image
41
 
@@ -45,7 +46,7 @@ async def process_frame(file: UploadFile = File(...)):
45
  try:
46
  # Read and preprocess image
47
  image = Image.open(io.BytesIO(await file.read()))
48
- image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
 
50
  # Detect accidents
51
  processed_image = detect_accident(image)
@@ -55,10 +56,8 @@ async def process_frame(file: UploadFile = File(...)):
55
  processed_image.save(img_byte_arr, format="JPEG")
56
  img_byte_arr.seek(0)
57
 
58
- return JSONResponse(
59
- content={"status": "success", "message": "Frame processed successfully"},
60
- media_type="image/jpeg"
61
- )
62
  except Exception as e:
63
  return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
 
@@ -75,6 +74,7 @@ if __name__ == "__main__":
75
 
76
 
77
 
 
78
 
79
 
80
  # import gradio as gr
 
1
  from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
 
12
  # Add CORS middleware to allow communication with external clients
13
  app.add_middleware(
14
  CORSMiddleware,
15
+ allow_origins=["*"], # Change this to specific domains in production
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
 
35
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  x_min, y_min, x_max, y_max = box
37
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
+ label_name = model.config.id2label[label.item()]
39
+ draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
40
 
41
  return image
42
 
 
46
  try:
47
  # Read and preprocess image
48
  image = Image.open(io.BytesIO(await file.read()))
49
+ image = image.convert("RGB") # Ensure compatibility with the model
50
 
51
  # Detect accidents
52
  processed_image = detect_accident(image)
 
56
  processed_image.save(img_byte_arr, format="JPEG")
57
  img_byte_arr.seek(0)
58
 
59
+ # Return the image as a streaming response
60
+ return StreamingResponse(img_byte_arr, media_type="image/jpeg")
 
 
61
  except Exception as e:
62
  return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
63
 
 
74
 
75
 
76
 
77
+
78
 
79
 
80
  # import gradio as gr