boostedhug commited on
Commit
7a62d82
·
verified ·
1 Parent(s): fb2938d
Files changed (1) hide show
  1. app.py +133 -47
app.py CHANGED
@@ -1,68 +1,154 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
6
- import io
7
  import torch
 
 
8
 
9
- # Initialize FastAPI app
10
- app = FastAPI()
11
-
12
- # Add CORS middleware to allow communication with external clients
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # Change this to the specific domain in production
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
- # Load the model and processor
21
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
 
 
24
  def detect_accident(image):
25
- """Runs accident detection on the input image."""
26
  inputs = processor(images=image, return_tensors="pt")
27
  outputs = model(**inputs)
28
 
29
- # Post-process results
30
  target_sizes = torch.tensor([image.size[::-1]])
31
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
- # Draw bounding boxes and labels
34
  draw = ImageDraw.Draw(image)
35
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  x_min, y_min, x_max, y_max = box
37
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
-
40
  return image
41
 
42
- @app.post("/detect_accident")
43
- async def process_frame(file: UploadFile = File(...)):
44
- """API endpoint to process an uploaded frame."""
45
- try:
46
- # Read and preprocess image
47
- image = Image.open(io.BytesIO(await file.read()))
48
- image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
-
50
- # Detect accidents
51
- processed_image = detect_accident(image)
52
-
53
- # Save the processed image into bytes to send back
54
- img_byte_arr = io.BytesIO()
55
- processed_image.save(img_byte_arr, format="JPEG")
56
- img_byte_arr.seek(0)
57
-
58
- return JSONResponse(
59
- content={"status": "success", "message": "Frame processed successfully"},
60
- media_type="image/jpeg"
61
- )
62
- except Exception as e:
63
- return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
-
65
- # Run the app
66
- if __name__ == "__main__":
67
- import uvicorn
68
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image, ImageDraw
 
4
  import torch
5
+ import cv2
6
+ import numpy as np
7
 
8
+ # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
9
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
10
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
11
 
12
+ # Function to detect accidents in an image
13
  def detect_accident(image):
 
14
  inputs = processor(images=image, return_tensors="pt")
15
  outputs = model(**inputs)
16
 
17
+ # Post-process the results
18
  target_sizes = torch.tensor([image.size[::-1]])
19
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
20
 
21
+ # Draw boxes and labels on the image
22
  draw = ImageDraw.Draw(image)
23
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
24
  x_min, y_min, x_max, y_max = box
25
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
26
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
27
+
28
  return image
29
 
30
+ # Function to detect accidents frame-by-frame in a video
31
+ def detect_accident_in_video(video_path):
32
+ cap = cv2.VideoCapture(video_path)
33
+ frames = []
34
+ while True:
35
+ ret, frame = cap.read()
36
+ if not ret:
37
+ break
38
+
39
+ # Convert frame to PIL Image
40
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
41
+ pil_frame = Image.fromarray(frame_rgb)
42
+
43
+ # Run accident detection on the frame
44
+ processed_frame = detect_accident(pil_frame)
45
+
46
+ # Convert PIL image back to numpy array for video
47
+ frames.append(np.array(processed_frame))
48
+
49
+ cap.release()
50
+
51
+ # Save processed frames as output video
52
+ height, width, _ = frames[0].shape
53
+ out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
54
+ for frame in frames:
55
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
56
+ out.release()
57
+
58
+ return "output.mp4"
59
+
60
+ # Gradio app interface
61
+ with gr.Blocks() as interface:
62
+ gr.Markdown("# Traffic Accident Detection")
63
+ gr.Markdown(
64
+ "Upload an image or video to detect traffic accidents using the DETR model. "
65
+ "For videos, the system processes frame by frame and outputs a new video with accident detection."
66
+ )
67
+
68
+ # Input components
69
+ with gr.Tab("Image Input"):
70
+ image_input = gr.Image(type="pil", label="Upload Image")
71
+ image_output = gr.Image(type="pil", label="Detection Output")
72
+ image_button = gr.Button("Detect Accidents in Image")
73
+
74
+ with gr.Tab("Video Input"):
75
+ video_input = gr.Video(label="Upload Video")
76
+ video_output = gr.Video(label="Processed Video")
77
+ video_button = gr.Button("Detect Accidents in Video")
78
+
79
+ # Define behaviors
80
+ image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
81
+ video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
82
+
83
+ interface.launch()
84
+
85
+
86
+
87
+ # from fastapi import FastAPI, File, UploadFile
88
+ # from fastapi.responses import JSONResponse
89
+ # from fastapi.middleware.cors import CORSMiddleware
90
+ # from transformers import DetrImageProcessor, DetrForObjectDetection
91
+ # from PIL import Image, ImageDraw
92
+ # import io
93
+ # import torch
94
+
95
+ # # Initialize FastAPI app
96
+ # app = FastAPI()
97
+
98
+ # # Add CORS middleware to allow communication with external clients
99
+ # app.add_middleware(
100
+ # CORSMiddleware,
101
+ # allow_origins=["*"], # Change this to the specific domain in production
102
+ # allow_methods=["*"],
103
+ # allow_headers=["*"],
104
+ # )
105
+
106
+ # # Load the model and processor
107
+ # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
108
+ # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
109
+
110
+ # def detect_accident(image):
111
+ # """Runs accident detection on the input image."""
112
+ # inputs = processor(images=image, return_tensors="pt")
113
+ # outputs = model(**inputs)
114
+
115
+ # # Post-process results
116
+ # target_sizes = torch.tensor([image.size[::-1]])
117
+ # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
118
+
119
+ # # Draw bounding boxes and labels
120
+ # draw = ImageDraw.Draw(image)
121
+ # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
122
+ # x_min, y_min, x_max, y_max = box
123
+ # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
124
+ # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
125
+
126
+ # return image
127
+
128
+ # @app.post("/detect_accident")
129
+ # async def process_frame(file: UploadFile = File(...)):
130
+ # """API endpoint to process an uploaded frame."""
131
+ # try:
132
+ # # Read and preprocess image
133
+ # image = Image.open(io.BytesIO(await file.read()))
134
+ # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
135
+
136
+ # # Detect accidents
137
+ # processed_image = detect_accident(image)
138
+
139
+ # # Save the processed image into bytes to send back
140
+ # img_byte_arr = io.BytesIO()
141
+ # processed_image.save(img_byte_arr, format="JPEG")
142
+ # img_byte_arr.seek(0)
143
+
144
+ # return JSONResponse(
145
+ # content={"status": "success", "message": "Frame processed successfully"},
146
+ # media_type="image/jpeg"
147
+ # )
148
+ # except Exception as e:
149
+ # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
150
+
151
+ # # Run the app
152
+ # if __name__ == "__main__":
153
+ # import uvicorn
154
+ # uvicorn.run(app, host="0.0.0.0", port=8000)