boostedhug commited on
Commit
1368d21
·
verified ·
1 Parent(s): a7f5a45
Files changed (1) hide show
  1. app.py +127 -127
app.py CHANGED
@@ -1,152 +1,152 @@
1
- # from fastapi import FastAPI, File, UploadFile
2
- # from fastapi.responses import JSONResponse
3
- # from fastapi.middleware.cors import CORSMiddleware
4
- # from transformers import DetrImageProcessor, DetrForObjectDetection
5
- # from PIL import Image, ImageDraw
6
- # import io
7
- # import torch
8
-
9
- # # Initialize FastAPI app
10
- # app = FastAPI()
11
-
12
- # # Add CORS middleware to allow communication with external clients
13
- # app.add_middleware(
14
- # CORSMiddleware,
15
- # allow_origins=["*"], # Change this to the specific domain in production
16
- # allow_methods=["*"],
17
- # allow_headers=["*"],
18
- # )
19
-
20
- # # Load the model and processor
21
- # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
- # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
-
24
- # def detect_accident(image):
25
- # """Runs accident detection on the input image."""
26
- # inputs = processor(images=image, return_tensors="pt")
27
- # outputs = model(**inputs)
28
-
29
- # # Post-process results
30
- # target_sizes = torch.tensor([image.size[::-1]])
31
- # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
-
33
- # # Draw bounding boxes and labels
34
- # draw = ImageDraw.Draw(image)
35
- # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
- # x_min, y_min, x_max, y_max = box
37
- # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
- # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
-
40
- # return image
41
-
42
- # @app.post("/detect_accident")
43
- # async def process_frame(file: UploadFile = File(...)):
44
- # """API endpoint to process an uploaded frame."""
45
- # try:
46
- # # Read and preprocess image
47
- # image = Image.open(io.BytesIO(await file.read()))
48
- # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
-
50
- # # Detect accidents
51
- # processed_image = detect_accident(image)
52
-
53
- # # Save the processed image into bytes to send back
54
- # img_byte_arr = io.BytesIO()
55
- # processed_image.save(img_byte_arr, format="JPEG")
56
- # img_byte_arr.seek(0)
57
-
58
- # return JSONResponse(
59
- # content={"status": "success", "message": "Frame processed successfully"},
60
- # media_type="image/jpeg"
61
- # )
62
- # except Exception as e:
63
- # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
-
65
- # # Run the app
66
- # if __name__ == "__main__":
67
- # import uvicorn
68
- # uvicorn.run(app, host="0.0.0.0", port=8000)
69
-
70
- import gradio as gr
71
  from transformers import DetrImageProcessor, DetrForObjectDetection
72
  from PIL import Image, ImageDraw
 
73
  import torch
74
- import cv2
75
- import numpy as np
76
 
77
- # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
78
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
79
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
 
81
- # Function to detect accidents in an image
82
  def detect_accident(image):
 
83
  inputs = processor(images=image, return_tensors="pt")
84
  outputs = model(**inputs)
85
 
86
- # Post-process the results
87
  target_sizes = torch.tensor([image.size[::-1]])
88
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
 
90
- # Draw boxes and labels on the image
91
  draw = ImageDraw.Draw(image)
92
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
93
  x_min, y_min, x_max, y_max = box
94
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
95
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
96
-
97
  return image
98
 
99
- # Function to detect accidents frame-by-frame in a video
100
- def detect_accident_in_video(video_path):
101
- cap = cv2.VideoCapture(video_path)
102
- frames = []
103
- while True:
104
- ret, frame = cap.read()
105
- if not ret:
106
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # Convert frame to PIL Image
109
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
110
- pil_frame = Image.fromarray(frame_rgb)
111
-
112
- # Run accident detection on the frame
113
- processed_frame = detect_accident(pil_frame)
114
-
115
- # Convert PIL image back to numpy array for video
116
- frames.append(np.array(processed_frame))
117
-
118
- cap.release()
119
-
120
- # Save processed frames as output video
121
- height, width, _ = frames[0].shape
122
- out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
123
- for frame in frames:
124
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
125
- out.release()
126
-
127
- return "output.mp4"
128
-
129
- # Gradio app interface
130
- with gr.Blocks() as interface:
131
- gr.Markdown("# Traffic Accident Detection")
132
- gr.Markdown(
133
- "Upload an image or video to detect traffic accidents using the DETR model. "
134
- "For videos, the system processes frame by frame and outputs a new video with accident detection."
135
- )
136
-
137
- # Input components
138
- with gr.Tab("Image Input"):
139
- image_input = gr.Image(type="pil", label="Upload Image")
140
- image_output = gr.Image(type="pil", label="Detection Output")
141
- image_button = gr.Button("Detect Accidents in Image")
142
 
143
- with gr.Tab("Video Input"):
144
- video_input = gr.Video(label="Upload Video")
145
- video_output = gr.Video(label="Processed Video")
146
- video_button = gr.Button("Detect Accidents in Video")
147
 
148
- # Define behaviors
149
- image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
150
- video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
151
 
152
- interface.launch()
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
6
+ import io
7
  import torch
 
 
8
 
9
+ # Initialize FastAPI app
10
+ app = FastAPI()
11
+
12
+ # Add CORS middleware to allow communication with external clients
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"], # Change this to the specific domain in production
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Load the model and processor
21
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
 
 
24
  def detect_accident(image):
25
+ """Runs accident detection on the input image."""
26
  inputs = processor(images=image, return_tensors="pt")
27
  outputs = model(**inputs)
28
 
29
+ # Post-process results
30
  target_sizes = torch.tensor([image.size[::-1]])
31
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
+ # Draw bounding boxes and labels
34
  draw = ImageDraw.Draw(image)
35
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  x_min, y_min, x_max, y_max = box
37
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
+
40
  return image
41
 
42
+ @app.post("/detect_accident")
43
+ async def process_frame(file: UploadFile = File(...)):
44
+ """API endpoint to process an uploaded frame."""
45
+ try:
46
+ # Read and preprocess image
47
+ image = Image.open(io.BytesIO(await file.read()))
48
+ image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
+
50
+ # Detect accidents
51
+ processed_image = detect_accident(image)
52
+
53
+ # Save the processed image into bytes to send back
54
+ img_byte_arr = io.BytesIO()
55
+ processed_image.save(img_byte_arr, format="JPEG")
56
+ img_byte_arr.seek(0)
57
+
58
+ return JSONResponse(
59
+ content={"status": "success", "message": "Frame processed successfully"},
60
+ media_type="image/jpeg"
61
+ )
62
+ except Exception as e:
63
+ return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
+
65
+ # Run the app
66
+ if __name__ == "__main__":
67
+ import uvicorn
68
+ uvicorn.run(app, host="0.0.0.0", port=8000)
69
+
70
+ # import gradio as gr
71
+ # from transformers import DetrImageProcessor, DetrForObjectDetection
72
+ # from PIL import Image, ImageDraw
73
+ # import torch
74
+ # import cv2
75
+ # import numpy as np
76
+
77
+ # # Load model and processor
78
+ # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
79
+ # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
+
81
+ # # Function to detect accidents in an image
82
+ # def detect_accident(image):
83
+ # inputs = processor(images=image, return_tensors="pt")
84
+ # outputs = model(**inputs)
85
+
86
+ # # Post-process the results
87
+ # target_sizes = torch.tensor([image.size[::-1]])
88
+ # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
+
90
+ # # Draw boxes and labels on the image
91
+ # draw = ImageDraw.Draw(image)
92
+ # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
93
+ # x_min, y_min, x_max, y_max = box
94
+ # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
95
+ # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
96
+
97
+ # return image
98
+
99
+ # # Function to detect accidents frame-by-frame in a video
100
+ # def detect_accident_in_video(video_path):
101
+ # cap = cv2.VideoCapture(video_path)
102
+ # frames = []
103
+ # while True:
104
+ # ret, frame = cap.read()
105
+ # if not ret:
106
+ # break
107
 
108
+ # # Convert frame to PIL Image
109
+ # frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
110
+ # pil_frame = Image.fromarray(frame_rgb)
111
+
112
+ # # Run accident detection on the frame
113
+ # processed_frame = detect_accident(pil_frame)
114
+
115
+ # # Convert PIL image back to numpy array for video
116
+ # frames.append(np.array(processed_frame))
117
+
118
+ # cap.release()
119
+
120
+ # # Save processed frames as output video
121
+ # height, width, _ = frames[0].shape
122
+ # out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
123
+ # for frame in frames:
124
+ # out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
125
+ # out.release()
126
+
127
+ # return "output.mp4"
128
+
129
+ # # Gradio app interface
130
+ # with gr.Blocks() as interface:
131
+ # gr.Markdown("# Traffic Accident Detection")
132
+ # gr.Markdown(
133
+ # "Upload an image or video to detect traffic accidents using the DETR model. "
134
+ # "For videos, the system processes frame by frame and outputs a new video with accident detection."
135
+ # )
136
+
137
+ # # Input components
138
+ # with gr.Tab("Image Input"):
139
+ # image_input = gr.Image(type="pil", label="Upload Image")
140
+ # image_output = gr.Image(type="pil", label="Detection Output")
141
+ # image_button = gr.Button("Detect Accidents in Image")
142
 
143
+ # with gr.Tab("Video Input"):
144
+ # video_input = gr.Video(label="Upload Video")
145
+ # video_output = gr.Video(label="Processed Video")
146
+ # video_button = gr.Button("Detect Accidents in Video")
147
 
148
+ # # Define behaviors
149
+ # image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
150
+ # video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
151
 
152
+ # interface.launch()