boostedhug committed on
Commit
a7f5a45
·
verified ·
1 Parent(s): 04b799b

re-added UI, test

Browse files
Files changed (1) hide show
  1. app.py +127 -127
app.py CHANGED
@@ -1,152 +1,152 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from transformers import DetrImageProcessor, DetrForObjectDetection
5
- from PIL import Image, ImageDraw
6
- import io
7
- import torch
8
-
9
- # Initialize FastAPI app
10
- app = FastAPI()
11
-
12
- # Add CORS middleware to allow communication with external clients
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # Change this to the specific domain in production
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
- # Load the model and processor
21
- model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
- processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
-
24
- def detect_accident(image):
25
- """Runs accident detection on the input image."""
26
- inputs = processor(images=image, return_tensors="pt")
27
- outputs = model(**inputs)
28
-
29
- # Post-process results
30
- target_sizes = torch.tensor([image.size[::-1]])
31
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
-
33
- # Draw bounding boxes and labels
34
- draw = ImageDraw.Draw(image)
35
- for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
- x_min, y_min, x_max, y_max = box
37
- draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
- draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
-
40
- return image
41
-
42
- @app.post("/detect_accident")
43
- async def process_frame(file: UploadFile = File(...)):
44
- """API endpoint to process an uploaded frame."""
45
- try:
46
- # Read and preprocess image
47
- image = Image.open(io.BytesIO(await file.read()))
48
- image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
-
50
- # Detect accidents
51
- processed_image = detect_accident(image)
52
-
53
- # Save the processed image into bytes to send back
54
- img_byte_arr = io.BytesIO()
55
- processed_image.save(img_byte_arr, format="JPEG")
56
- img_byte_arr.seek(0)
57
-
58
- return JSONResponse(
59
- content={"status": "success", "message": "Frame processed successfully"},
60
- media_type="image/jpeg"
61
- )
62
- except Exception as e:
63
- return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
-
65
- # Run the app
66
- if __name__ == "__main__":
67
- import uvicorn
68
- uvicorn.run(app, host="0.0.0.0", port=8000)
69
-
70
- # import gradio as gr
71
  # from transformers import DetrImageProcessor, DetrForObjectDetection
72
  # from PIL import Image, ImageDraw
 
73
  # import torch
74
- # import cv2
75
- # import numpy as np
76
 
77
- # # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
78
  # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
79
  # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
 
81
- # # Function to detect accidents in an image
82
  # def detect_accident(image):
 
83
  # inputs = processor(images=image, return_tensors="pt")
84
  # outputs = model(**inputs)
85
 
86
- # # Post-process the results
87
  # target_sizes = torch.tensor([image.size[::-1]])
88
  # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
 
90
- # # Draw boxes and labels on the image
91
  # draw = ImageDraw.Draw(image)
92
  # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
93
  # x_min, y_min, x_max, y_max = box
94
  # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
95
  # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
96
-
97
  # return image
98
 
99
- # # Function to detect accidents frame-by-frame in a video
100
- # def detect_accident_in_video(video_path):
101
- # cap = cv2.VideoCapture(video_path)
102
- # frames = []
103
- # while True:
104
- # ret, frame = cap.read()
105
- # if not ret:
106
- # break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # # Convert frame to PIL Image
109
- # frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
110
- # pil_frame = Image.fromarray(frame_rgb)
111
-
112
- # # Run accident detection on the frame
113
- # processed_frame = detect_accident(pil_frame)
114
-
115
- # # Convert PIL image back to numpy array for video
116
- # frames.append(np.array(processed_frame))
117
-
118
- # cap.release()
119
-
120
- # # Save processed frames as output video
121
- # height, width, _ = frames[0].shape
122
- # out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
123
- # for frame in frames:
124
- # out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
125
- # out.release()
126
-
127
- # return "output.mp4"
128
-
129
- # # Gradio app interface
130
- # with gr.Blocks() as interface:
131
- # gr.Markdown("# Traffic Accident Detection")
132
- # gr.Markdown(
133
- # "Upload an image or video to detect traffic accidents using the DETR model. "
134
- # "For videos, the system processes frame by frame and outputs a new video with accident detection."
135
- # )
136
-
137
- # # Input components
138
- # with gr.Tab("Image Input"):
139
- # image_input = gr.Image(type="pil", label="Upload Image")
140
- # image_output = gr.Image(type="pil", label="Detection Output")
141
- # image_button = gr.Button("Detect Accidents in Image")
142
 
143
- # with gr.Tab("Video Input"):
144
- # video_input = gr.Video(label="Upload Video")
145
- # video_output = gr.Video(label="Processed Video")
146
- # video_button = gr.Button("Detect Accidents in Video")
147
 
148
- # # Define behaviors
149
- # image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
150
- # video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
151
 
152
- # interface.launch()
 
1
+ # from fastapi import FastAPI, File, UploadFile
2
+ # from fastapi.responses import JSONResponse
3
+ # from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # from transformers import DetrImageProcessor, DetrForObjectDetection
5
  # from PIL import Image, ImageDraw
6
+ # import io
7
  # import torch
 
 
8
 
9
+ # # Initialize FastAPI app
10
+ # app = FastAPI()
11
+
12
+ # # Add CORS middleware to allow communication with external clients
13
+ # app.add_middleware(
14
+ # CORSMiddleware,
15
+ # allow_origins=["*"], # Change this to the specific domain in production
16
+ # allow_methods=["*"],
17
+ # allow_headers=["*"],
18
+ # )
19
+
20
+ # # Load the model and processor
21
  # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
  # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
 
 
24
  # def detect_accident(image):
25
+ # """Runs accident detection on the input image."""
26
  # inputs = processor(images=image, return_tensors="pt")
27
  # outputs = model(**inputs)
28
 
29
+ # # Post-process results
30
  # target_sizes = torch.tensor([image.size[::-1]])
31
  # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
+ # # Draw bounding boxes and labels
34
  # draw = ImageDraw.Draw(image)
35
  # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  # x_min, y_min, x_max, y_max = box
37
  # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
  # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
+
40
  # return image
41
 
42
+ # @app.post("/detect_accident")
43
+ # async def process_frame(file: UploadFile = File(...)):
44
+ # """API endpoint to process an uploaded frame."""
45
+ # try:
46
+ # # Read and preprocess image
47
+ # image = Image.open(io.BytesIO(await file.read()))
48
+ # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
+
50
+ # # Detect accidents
51
+ # processed_image = detect_accident(image)
52
+
53
+ # # Save the processed image into bytes to send back
54
+ # img_byte_arr = io.BytesIO()
55
+ # processed_image.save(img_byte_arr, format="JPEG")
56
+ # img_byte_arr.seek(0)
57
+
58
+ # return JSONResponse(
59
+ # content={"status": "success", "message": "Frame processed successfully"},
60
+ # media_type="image/jpeg"
61
+ # )
62
+ # except Exception as e:
63
+ # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
+
65
+ # # Run the app
66
+ # if __name__ == "__main__":
67
+ # import uvicorn
68
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
69
+
70
+ import gradio as gr
71
+ from transformers import DetrImageProcessor, DetrForObjectDetection
72
+ from PIL import Image, ImageDraw
73
+ import torch
74
+ import cv2
75
+ import numpy as np
76
+
77
+ # Load model and processor
78
# The detector weights and the matching pre-processor are loaded from the
# same checkpoint identifier, so name it once.
MODEL_ID = "hilmantm/detr-traffic-accident-detection"
model = DetrForObjectDetection.from_pretrained(MODEL_ID)
processor = DetrImageProcessor.from_pretrained(MODEL_ID)
80
+
81
+ # Function to detect accidents in an image
82
def detect_accident(image):
    """Detect traffic accidents in an image and draw the detections on it.

    Runs the module-level ``processor`` and ``model`` on ``image``, keeps the
    detections returned by post-processing with a 0.9 score threshold, and
    draws a red rectangle plus a ``"<label>: <score>"`` caption for each one.

    Args:
        image: A PIL image. It is drawn on in place.

    Returns:
        The same ``image`` object, annotated with the detections.
    """
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: there is no backward pass, so skip autograd tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's ``size`` is (width, height); reversing it yields (height, width)
    # for rescaling the predicted boxes to the image's pixel coordinates.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    id2label = model.config.id2label
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
        # Hand PIL plain Python numbers rather than 0-d tensors.
        x_min, y_min, x_max, y_max = box.tolist()
        class_id = label.item()
        # Prefer the human-readable class name; fall back to the raw id if
        # the checkpoint's config does not name this class.
        class_name = id2label.get(class_id, class_id)
        draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
        draw.text((x_min, y_min), f"{class_name}: {score.item():.2f}", fill="red")

    return image
98
+
99
+ # Function to detect accidents frame-by-frame in a video
100
def detect_accident_in_video(video_path):
    """Run accident detection on every frame of a video and save the result.

    Each frame is read with OpenCV, converted to an RGB PIL image, passed to
    ``detect_accident``, and written straight to ``output.mp4`` so that only
    one frame is held in memory at a time.

    Args:
        video_path: Path of the video file to process.

    Returns:
        The path of the annotated video, ``"output.mp4"``.

    Raises:
        ValueError: If no path is given, the video cannot be opened, or no
            frame can be read from it.
    """
    if not video_path:
        raise ValueError("No video was provided.")

    output_path = "output.mp4"
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        # Keep the source playback speed; fall back to the previous fixed
        # 10 fps when the container reports no usable rate (0, negative, NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 10

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV yields BGR arrays; the detector works on RGB PIL images.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            processed_frame = np.array(detect_accident(Image.fromarray(frame_rgb)))

            # Create the writer lazily so its size matches the first
            # processed frame.
            if out is None:
                height, width = processed_frame.shape[:2]
                out = cv2.VideoWriter(
                    output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
                )

            out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release native handles even if detection fails part-way through.
        cap.release()
        if out is not None:
            out.release()

    if out is None:
        raise ValueError("No frames could be read from the video.")

    return output_path
128
+
129
+ # Gradio app interface
130
with gr.Blocks() as interface:
    # Page header and short usage note.
    gr.Markdown(value="# Traffic Accident Detection")
    gr.Markdown(
        value=(
            "Upload an image or video to detect traffic accidents using the DETR model. "
            "For videos, the system processes frame by frame and outputs a new video with accident detection."
        )
    )

    # Still-image tab: upload, annotated result, trigger button.
    with gr.Tab("Image Input"):
        image_input = gr.Image(type="pil", label="Upload Image")
        image_output = gr.Image(type="pil", label="Detection Output")
        image_button = gr.Button("Detect Accidents in Image")
        image_button.click(
            fn=detect_accident,
            inputs=image_input,
            outputs=image_output,
        )

    # Video tab: upload, annotated result, trigger button.
    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload Video")
        video_output = gr.Video(label="Processed Video")
        video_button = gr.Button("Detect Accidents in Video")
        video_button.click(
            fn=detect_accident_in_video,
            inputs=video_input,
            outputs=video_output,
        )

interface.launch()