boostedhug committed on
Commit
6e8e295
·
verified ·
1 Parent(s): b1c89e5
Files changed (1) hide show
  1. app.py +56 -114
app.py CHANGED
@@ -1,129 +1,71 @@
1
- import gradio as gr
 
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image, ImageDraw
 
4
  import torch
5
 
6
- # Load the DETR model and processor
7
- model_name = "hilmantm/detr-traffic-accident-detection"
8
- processor = DetrImageProcessor.from_pretrained(model_name)
9
- model = DetrForObjectDetection.from_pretrained(model_name)
10
 
11
- def detect_accident(image):
12
- """Process an image and detect traffic accidents using the DETR model."""
13
- # Preprocess the input image
14
- inputs = processor(images=image, return_tensors="pt", size={"longest_edge": 800})
15
-
16
- # Get model predictions
17
- outputs = model(**inputs)
18
-
19
- # Post-process predictions to extract bounding boxes and labels
20
- target_sizes = torch.tensor([image.size[::-1]]) # Image size in (height, width)
21
- results = processor.post_process_object_detection(
22
- outputs, target_sizes=target_sizes, threshold=0.9
23
- )[0]
24
-
25
- # Draw bounding boxes and labels on the image
26
- draw = ImageDraw.Draw(image)
27
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
28
- box = [int(b) for b in box]
29
- label_text = f"{model.config.id2label[label]}: {score:.2f}"
30
- draw.rectangle(box, outline="red", width=3)
31
- draw.text((box[0], box[1]), label_text, fill="red")
32
-
33
- return image
34
-
35
- # Define the Gradio interface
36
- iface = gr.Interface(
37
- fn=detect_accident,
38
- inputs=gr.Image(type="pil"),
39
- outputs=gr.Image(type="pil"),
40
- title="Traffic Accident Detection",
41
- description="Upload an image to detect traffic accidents using the DETR model."
42
  )
43
 
44
- # Launch the app
45
- iface.launch()
46
-
47
-
48
-
49
 
 
 
 
 
50
 
 
 
 
51
 
52
-
53
-
54
-
55
-
56
-
57
-
58
-
59
- # from fastapi import FastAPI, File, UploadFile
60
- # from fastapi.responses import JSONResponse
61
- # from fastapi.middleware.cors import CORSMiddleware
62
- # from transformers import DetrImageProcessor, DetrForObjectDetection
63
- # from PIL import Image, ImageDraw
64
- # import io
65
- # import torch
66
-
67
- # # Initialize FastAPI app
68
- # app = FastAPI()
69
-
70
- # # Add CORS middleware to allow communication with external clients
71
- # app.add_middleware(
72
- # CORSMiddleware,
73
- # allow_origins=["*"], # Change this to the specific domain in production
74
- # allow_methods=["*"],
75
- # allow_headers=["*"],
76
- # )
77
-
78
- # # Load the model and processor
79
- # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
- # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
81
-
82
- # def detect_accident(image):
83
- # """Runs accident detection on the input image."""
84
- # inputs = processor(images=image, return_tensors="pt")
85
- # outputs = model(**inputs)
86
-
87
- # # Post-process results
88
- # target_sizes = torch.tensor([image.size[::-1]])
89
- # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
90
-
91
- # # Draw bounding boxes and labels
92
- # draw = ImageDraw.Draw(image)
93
- # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
94
- # x_min, y_min, x_max, y_max = box
95
- # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
96
- # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
97
 
98
- # return image
99
 
100
- # @app.post("/detect_accident")
101
- # async def process_frame(file: UploadFile = File(...)):
102
- # """API endpoint to process an uploaded frame."""
103
- # try:
104
- # # Read and preprocess image
105
- # image = Image.open(io.BytesIO(await file.read()))
106
- # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
107
-
108
- # # Detect accidents
109
- # processed_image = detect_accident(image)
110
-
111
- # # Save the processed image into bytes to send back
112
- # img_byte_arr = io.BytesIO()
113
- # processed_image.save(img_byte_arr, format="JPEG")
114
- # img_byte_arr.seek(0)
115
-
116
- # return JSONResponse(
117
- # content={"status": "success", "message": "Frame processed successfully"},
118
- # media_type="image/jpeg"
119
- # )
120
- # except Exception as e:
121
- # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
122
-
123
- # # Run the app
124
- # if __name__ == "__main__":
125
- # import uvicorn
126
- # uvicorn.run(app, host="0.0.0.0", port=8000)
127
 
128
 
129
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
6
+ import io
7
  import torch
8
 
9
# Application bootstrap: FastAPI instance, CORS policy, and detection model.
app = FastAPI()

# Allow any origin/method/header so browser clients on other domains can call
# the API. Tighten `allow_origins` to specific domains before production use.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Pull the DETR traffic-accident checkpoint and its matching image processor.
_CHECKPOINT = "hilmantm/detr-traffic-accident-detection"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
 
 
23
 
24
def detect_accident(image):
    """Run accident detection on a PIL image and draw the results on it.

    Args:
        image: PIL.Image.Image to analyse. The image is annotated in place.

    Returns:
        The same PIL image with a red bounding box and a "<label>: <score>"
        caption drawn for every detection above the confidence threshold.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: disable gradient tracking to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_object_detection expects per-image (height, width);
    # PIL's .size is (width, height), hence the reversal.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    draw = ImageDraw.Draw(image)
    for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
        # Convert tensor coordinates to plain Python floats for PIL.
        x_min, y_min, x_max, y_max = (float(v) for v in box.tolist())
        # BUG FIX: `label` is a tensor class id; map it through id2label so the
        # caption shows the class name instead of "tensor(n)".
        label_name = model.config.id2label[int(label)]
        draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
        draw.text((x_min, y_min), f"{label_name}: {score.item():.2f}", fill="red")

    return image
41
 
42
@app.post("/detect_accident")
async def process_frame(file: UploadFile = File(...)):
    """API endpoint: run accident detection on one uploaded frame.

    Returns:
        On success, the annotated frame as a JPEG response body.
        On failure, a JSON error payload with HTTP status 500.
    """
    try:
        # Read and decode the uploaded image.
        image = Image.open(io.BytesIO(await file.read()))
        # Downscale to width 256, preserving aspect ratio, to bound inference cost.
        image = image.resize((256, int(image.height * 256 / image.width)))
        # BUG FIX: JPEG cannot encode RGBA/palette images; normalise to RGB
        # so .save(..., format="JPEG") does not raise on PNG uploads.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Detect accidents and annotate the frame.
        processed_image = detect_accident(image)

        # Encode the annotated frame as JPEG bytes.
        img_byte_arr = io.BytesIO()
        processed_image.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)

        # BUG FIX: the original returned a JSON status dict while declaring
        # media_type="image/jpeg", so the processed frame was never sent.
        # Return the actual JPEG bytes as the response body instead.
        from fastapi.responses import Response  # local import keeps this edit self-contained
        return Response(content=img_byte_arr.getvalue(), media_type="image/jpeg")
    except Exception as e:
        # Boundary handler: report the failure to the client rather than crash.
        return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
+
65
+ # Run the app
66
+ if __name__ == "__main__":
67
+ import uvicorn
68
+ uvicorn.run(app, host="0.0.0.0", port=8000)
69
 
70
 
71