boostedhug commited on
Commit
5b8e275
·
verified ·
1 Parent(s): 8e83b28

new version added

Browse files
Files changed (1) hide show
  1. app.py +124 -57
app.py CHANGED
@@ -1,71 +1,138 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
- from PIL import Image, ImageDraw
6
- import io
7
  import torch
8
 
9
- # Initialize FastAPI app
10
- app = FastAPI()
11
-
12
- # Add CORS middleware to allow communication with external clients
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # Change this to the specific domain in production
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
- # Load the model and processor
21
- model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
- processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
 
24
  def detect_accident(image):
25
- """Runs accident detection on the input image."""
 
26
  inputs = processor(images=image, return_tensors="pt")
 
 
27
  outputs = model(**inputs)
28
-
29
- # Post-process results
30
- target_sizes = torch.tensor([image.size[::-1]])
31
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
-
33
- # Draw bounding boxes and labels
 
 
34
  draw = ImageDraw.Draw(image)
35
- for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
- x_min, y_min, x_max, y_max = box
37
- draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
- draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
 
39
 
40
  return image
41
 
42
- @app.post("/detect_accident")
43
- async def process_frame(file: UploadFile = File(...)):
44
- """API endpoint to process an uploaded frame."""
45
- try:
46
- # Read and preprocess image
47
- image = Image.open(io.BytesIO(await file.read()))
48
- image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
-
50
- # Detect accidents
51
- processed_image = detect_accident(image)
52
-
53
- # Save the processed image into bytes to send back
54
- img_byte_arr = io.BytesIO()
55
- processed_image.save(img_byte_arr, format="JPEG")
56
- img_byte_arr.seek(0)
57
-
58
- return JSONResponse(
59
- content={"status": "success", "message": "Frame processed successfully"},
60
- media_type="image/jpeg"
61
- )
62
- except Exception as e:
63
- return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
-
65
- # Run the app
66
- if __name__ == "__main__":
67
- import uvicorn
68
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  # import gradio as gr
71
  # from transformers import DetrImageProcessor, DetrForObjectDetection
 
1
+ import gradio as gr
 
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image
 
4
  import torch
5
 
6
+ # Load the DETR model and processor
7
+ model_name = "hilmantm/detr-traffic-accident-detection"
8
+ processor = DetrImageProcessor.from_pretrained(model_name)
9
+ model = DetrForObjectDetection.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
10
 
11
  def detect_accident(image):
12
+ """Process an image and detect traffic accidents using the DETR model."""
13
+ # Preprocess the input image
14
  inputs = processor(images=image, return_tensors="pt")
15
+
16
+ # Get model predictions
17
  outputs = model(**inputs)
18
+
19
+ # Post-process predictions to extract bounding boxes and labels
20
+ target_sizes = torch.tensor([image.size[::-1]]) # Image size in (height, width)
21
+ results = processor.post_process_object_detection(
22
+ outputs, target_sizes=target_sizes, threshold=0.9
23
+ )[0]
24
+
25
+ # Draw bounding boxes and labels on the image
26
  draw = ImageDraw.Draw(image)
27
+ for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
28
+ box = [int(b) for b in box]
29
+ label_text = f"{model.config.id2label[label]}: {score:.2f}"
30
+ draw.rectangle(box, outline="red", width=3)
31
+ draw.text((box[0], box[1]), label_text, fill="red")
32
 
33
  return image
34
 
35
+ # Define the Gradio interface
36
+ iface = gr.Interface(
37
+ fn=detect_accident,
38
+ inputs=gr.inputs.Image(type="pil"),
39
+ outputs=gr.outputs.Image(type="pil"),
40
+ title="Traffic Accident Detection",
41
+ description="Upload an image to detect traffic accidents using the DETR model."
42
+ )
43
+
44
+ # Launch the app
45
+ iface.launch()
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+ # from fastapi import FastAPI, File, UploadFile
59
+ # from fastapi.responses import JSONResponse
60
+ # from fastapi.middleware.cors import CORSMiddleware
61
+ # from transformers import DetrImageProcessor, DetrForObjectDetection
62
+ # from PIL import Image, ImageDraw
63
+ # import io
64
+ # import torch
65
+
66
+ # # Initialize FastAPI app
67
+ # app = FastAPI()
68
+
69
+ # # Add CORS middleware to allow communication with external clients
70
+ # app.add_middleware(
71
+ # CORSMiddleware,
72
+ # allow_origins=["*"], # Change this to the specific domain in production
73
+ # allow_methods=["*"],
74
+ # allow_headers=["*"],
75
+ # )
76
+
77
+ # # Load the model and processor
78
+ # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
79
+ # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
80
+
81
+ # def detect_accident(image):
82
+ # """Runs accident detection on the input image."""
83
+ # inputs = processor(images=image, return_tensors="pt")
84
+ # outputs = model(**inputs)
85
+
86
+ # # Post-process results
87
+ # target_sizes = torch.tensor([image.size[::-1]])
88
+ # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
89
+
90
+ # # Draw bounding boxes and labels
91
+ # draw = ImageDraw.Draw(image)
92
+ # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
93
+ # x_min, y_min, x_max, y_max = box
94
+ # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
95
+ # draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
96
+
97
+ # return image
98
+
99
+ # @app.post("/detect_accident")
100
+ # async def process_frame(file: UploadFile = File(...)):
101
+ # """API endpoint to process an uploaded frame."""
102
+ # try:
103
+ # # Read and preprocess image
104
+ # image = Image.open(io.BytesIO(await file.read()))
105
+ # image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
106
+
107
+ # # Detect accidents
108
+ # processed_image = detect_accident(image)
109
+
110
+ # # Save the processed image into bytes to send back
111
+ # img_byte_arr = io.BytesIO()
112
+ # processed_image.save(img_byte_arr, format="JPEG")
113
+ # img_byte_arr.seek(0)
114
+
115
+ # return JSONResponse(
116
+ # content={"status": "success", "message": "Frame processed successfully"},
117
+ # media_type="image/jpeg"
118
+ # )
119
+ # except Exception as e:
120
+ # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
121
+
122
+ # # Run the app
123
+ # if __name__ == "__main__":
124
+ # import uvicorn
125
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
126
+
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
 
137
  # import gradio as gr
138
  # from transformers import DetrImageProcessor, DetrForObjectDetection