boostedhug commited on
Commit
04b799b
·
verified ·
1 Parent(s): 7a62d82

re-added api

Browse files
Files changed (1) hide show
  1. app.py +110 -112
app.py CHANGED
@@ -1,154 +1,152 @@
1
- import gradio as gr
 
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image, ImageDraw
 
4
  import torch
5
- import cv2
6
- import numpy as np
7
 
8
- # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
9
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
10
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
11
 
12
- # Function to detect accidents in an image
13
  def detect_accident(image):
 
14
  inputs = processor(images=image, return_tensors="pt")
15
  outputs = model(**inputs)
16
 
17
- # Post-process the results
18
  target_sizes = torch.tensor([image.size[::-1]])
19
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
20
 
21
- # Draw boxes and labels on the image
22
  draw = ImageDraw.Draw(image)
23
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
24
  x_min, y_min, x_max, y_max = box
25
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
26
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
27
-
28
- return image
29
-
30
- # Function to detect accidents frame-by-frame in a video
31
- def detect_accident_in_video(video_path):
32
- cap = cv2.VideoCapture(video_path)
33
- frames = []
34
- while True:
35
- ret, frame = cap.read()
36
- if not ret:
37
- break
38
-
39
- # Convert frame to PIL Image
40
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
41
- pil_frame = Image.fromarray(frame_rgb)
42
-
43
- # Run accident detection on the frame
44
- processed_frame = detect_accident(pil_frame)
45
-
46
- # Convert PIL image back to numpy array for video
47
- frames.append(np.array(processed_frame))
48
-
49
- cap.release()
50
-
51
- # Save processed frames as output video
52
- height, width, _ = frames[0].shape
53
- out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
54
- for frame in frames:
55
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
56
- out.release()
57
-
58
- return "output.mp4"
59
-
60
- # Gradio app interface
61
- with gr.Blocks() as interface:
62
- gr.Markdown("# Traffic Accident Detection")
63
- gr.Markdown(
64
- "Upload an image or video to detect traffic accidents using the DETR model. "
65
- "For videos, the system processes frame by frame and outputs a new video with accident detection."
66
- )
67
-
68
- # Input components
69
- with gr.Tab("Image Input"):
70
- image_input = gr.Image(type="pil", label="Upload Image")
71
- image_output = gr.Image(type="pil", label="Detection Output")
72
- image_button = gr.Button("Detect Accidents in Image")
73
 
74
- with gr.Tab("Video Input"):
75
- video_input = gr.Video(label="Upload Video")
76
- video_output = gr.Video(label="Processed Video")
77
- video_button = gr.Button("Detect Accidents in Video")
78
-
79
- # Define behaviors
80
- image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
81
- video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
82
-
83
- interface.launch()
84
-
85
-
86
 
87
- # from fastapi import FastAPI, File, UploadFile
88
- # from fastapi.responses import JSONResponse
89
- # from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  # from transformers import DetrImageProcessor, DetrForObjectDetection
91
  # from PIL import Image, ImageDraw
92
- # import io
93
  # import torch
 
 
94
 
95
- # # Initialize FastAPI app
96
- # app = FastAPI()
97
-
98
- # # Add CORS middleware to allow communication with external clients
99
- # app.add_middleware(
100
- # CORSMiddleware,
101
- # allow_origins=["*"], # Change this to the specific domain in production
102
- # allow_methods=["*"],
103
- # allow_headers=["*"],
104
- # )
105
-
106
- # # Load the model and processor
107
  # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
108
  # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
109
 
 
110
  # def detect_accident(image):
111
- # """Runs accident detection on the input image."""
112
  # inputs = processor(images=image, return_tensors="pt")
113
  # outputs = model(**inputs)
114
 
115
- # # Post-process results
116
  # target_sizes = torch.tensor([image.size[::-1]])
117
  # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
118
 
119
- # # Draw bounding boxes and labels
120
  # draw = ImageDraw.Draw(image)
121
  # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
122
  # x_min, y_min, x_max, y_max = box
123
  # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
124
  # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
125
-
126
  # return image
127
 
128
- # @app.post("/detect_accident")
129
- # async def process_frame(file: UploadFile = File(...)):
130
- # """API endpoint to process an uploaded frame."""
131
- # try:
132
- # # Read and preprocess image
133
- # image = Image.open(io.BytesIO(await file.read()))
134
- # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
135
-
136
- # # Detect accidents
137
- # processed_image = detect_accident(image)
138
-
139
- # # Save the processed image into bytes to send back
140
- # img_byte_arr = io.BytesIO()
141
- # processed_image.save(img_byte_arr, format="JPEG")
142
- # img_byte_arr.seek(0)
143
-
144
- # return JSONResponse(
145
- # content={"status": "success", "message": "Frame processed successfully"},
146
- # media_type="image/jpeg"
147
- # )
148
- # except Exception as e:
149
- # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
150
-
151
- # # Run the app
152
- # if __name__ == "__main__":
153
- # import uvicorn
154
- # uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
6
+ import io
7
  import torch
 
 
8
 
9
+ # Initialize FastAPI app
10
+ app = FastAPI()
11
+
12
+ # Add CORS middleware to allow communication with external clients
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"], # Change this to the specific domain in production
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Load the model and processor
21
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
 
 
24
  def detect_accident(image):
25
+ """Runs accident detection on the input image."""
26
  inputs = processor(images=image, return_tensors="pt")
27
  outputs = model(**inputs)
28
 
29
+ # Post-process results
30
  target_sizes = torch.tensor([image.size[::-1]])
31
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
+ # Draw bounding boxes and labels
34
  draw = ImageDraw.Draw(image)
35
  for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
  x_min, y_min, x_max, y_max = box
37
  draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
  draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ return image
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ @app.post("/detect_accident")
43
+ async def process_frame(file: UploadFile = File(...)):
44
+ """API endpoint to process an uploaded frame."""
45
+ try:
46
+ # Read and preprocess image
47
+ image = Image.open(io.BytesIO(await file.read()))
48
+ image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
+
50
+ # Detect accidents
51
+ processed_image = detect_accident(image)
52
+
53
+ # Save the processed image into bytes to send back
54
+ img_byte_arr = io.BytesIO()
55
+ processed_image.save(img_byte_arr, format="JPEG")
56
+ img_byte_arr.seek(0)
57
+
58
+ return JSONResponse(
59
+ content={"status": "success", "message": "Frame processed successfully"},
60
+ media_type="image/jpeg"
61
+ )
62
+ except Exception as e:
63
+ return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
+
65
+ # Run the app
66
+ if __name__ == "__main__":
67
+ import uvicorn
68
+ uvicorn.run(app, host="0.0.0.0", port=8000)
69
+
70
+ # import gradio as gr
71
  # from transformers import DetrImageProcessor, DetrForObjectDetection
72
  # from PIL import Image, ImageDraw
 
73
  # import torch
74
+ # import cv2
75
+ # import numpy as np
76
 
77
+ # # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
78
  # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
79
  # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
 
81
+ # # Function to detect accidents in an image
82
  # def detect_accident(image):
 
83
  # inputs = processor(images=image, return_tensors="pt")
84
  # outputs = model(**inputs)
85
 
86
+ # # Post-process the results
87
  # target_sizes = torch.tensor([image.size[::-1]])
88
  # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
 
90
+ # # Draw boxes and labels on the image
91
  # draw = ImageDraw.Draw(image)
92
  # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
93
  # x_min, y_min, x_max, y_max = box
94
  # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
95
  # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
96
+
97
  # return image
98
 
99
+ # # Function to detect accidents frame-by-frame in a video
100
+ # def detect_accident_in_video(video_path):
101
+ # cap = cv2.VideoCapture(video_path)
102
+ # frames = []
103
+ # while True:
104
+ # ret, frame = cap.read()
105
+ # if not ret:
106
+ # break
107
+
108
+ # # Convert frame to PIL Image
109
+ # frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
110
+ # pil_frame = Image.fromarray(frame_rgb)
111
+
112
+ # # Run accident detection on the frame
113
+ # processed_frame = detect_accident(pil_frame)
114
+
115
+ # # Convert PIL image back to numpy array for video
116
+ # frames.append(np.array(processed_frame))
117
+
118
+ # cap.release()
119
+
120
+ # # Save processed frames as output video
121
+ # height, width, _ = frames[0].shape
122
+ # out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))
123
+ # for frame in frames:
124
+ # out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
125
+ # out.release()
126
+
127
+ # return "output.mp4"
128
+
129
+ # # Gradio app interface
130
+ # with gr.Blocks() as interface:
131
+ # gr.Markdown("# Traffic Accident Detection")
132
+ # gr.Markdown(
133
+ # "Upload an image or video to detect traffic accidents using the DETR model. "
134
+ # "For videos, the system processes frame by frame and outputs a new video with accident detection."
135
+ # )
136
+
137
+ # # Input components
138
+ # with gr.Tab("Image Input"):
139
+ # image_input = gr.Image(type="pil", label="Upload Image")
140
+ # image_output = gr.Image(type="pil", label="Detection Output")
141
+ # image_button = gr.Button("Detect Accidents in Image")
142
+
143
+ # with gr.Tab("Video Input"):
144
+ # video_input = gr.Video(label="Upload Video")
145
+ # video_output = gr.Video(label="Processed Video")
146
+ # video_button = gr.Button("Detect Accidents in Video")
147
+
148
+ # # Define behaviors
149
+ # image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
150
+ # video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
151
+
152
+ # interface.launch()