boostedhug commited on
Commit
936dcac
·
verified ·
1 Parent(s): 3f13787

From FastAPI -> GradioAPI

Browse files
Files changed (1) hide show
  1. app.py +93 -40
app.py CHANGED
@@ -1,22 +1,8 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import StreamingResponse, JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import DetrImageProcessor, DetrForObjectDetection
5
  from PIL import Image, ImageDraw
6
- import io
7
  import torch
8
 
9
- # Initialize FastAPI app
10
- app = FastAPI()
11
-
12
- # Add CORS middleware to allow communication with external clients
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # Change this to specific domains in production
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
  # Load the model and processor
21
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
@@ -40,31 +26,98 @@ def detect_accident(image):
40
 
41
  return image
42
 
43
- @app.post("/detect_accident")
44
- async def process_frame(file: UploadFile = File(...)):
45
- """API endpoint to process an uploaded frame."""
46
- try:
47
- # Read and preprocess image
48
- image = Image.open(io.BytesIO(await file.read()))
49
- image = image.convert("RGB") # Ensure compatibility with the model
50
-
51
- # Detect accidents
52
- processed_image = detect_accident(image)
53
-
54
- # Save the processed image into bytes to send back
55
- img_byte_arr = io.BytesIO()
56
- processed_image.save(img_byte_arr, format="JPEG")
57
- img_byte_arr.seek(0)
58
-
59
- # Return the image as a streaming response
60
- return StreamingResponse(img_byte_arr, media_type="image/jpeg")
61
- except Exception as e:
62
- return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
63
-
64
- # Run the app
65
- if __name__ == "__main__":
66
- import uvicorn
67
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
 
 
1
+ import gradio as gr
 
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image, ImageDraw
 
4
  import torch
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  # Load the model and processor
7
  model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
8
  processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
 
26
 
27
  return image
28
 
29
+ # Define the Gradio interface
30
+ def process_image(image):
31
+ processed_image = detect_accident(image)
32
+ return processed_image
33
+
34
+ # Launch the Gradio app
35
+ interface = gr.Interface(fn=process_image, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Image(type="pil"))
36
+ interface.launch(server_name="0.0.0.0", server_port=8000)
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+ # from fastapi import FastAPI, File, UploadFile
55
+ # from fastapi.responses import StreamingResponse, JSONResponse
56
+ # from fastapi.middleware.cors import CORSMiddleware
57
+ # from transformers import DetrImageProcessor, DetrForObjectDetection
58
+ # from PIL import Image, ImageDraw
59
+ # import io
60
+ # import torch
61
+
62
+ # # Initialize FastAPI app
63
+ # app = FastAPI()
64
+
65
+ # # Add CORS middleware to allow communication with external clients
66
+ # app.add_middleware(
67
+ # CORSMiddleware,
68
+ # allow_origins=["*"], # Change this to specific domains in production
69
+ # allow_methods=["*"],
70
+ # allow_headers=["*"],
71
+ # )
72
+
73
+ # # Load the model and processor
74
+ # model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
75
+ # processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
76
+
77
+ # def detect_accident(image):
78
+ # """Runs accident detection on the input image."""
79
+ # inputs = processor(images=image, return_tensors="pt")
80
+ # outputs = model(**inputs)
81
+
82
+ # # Post-process results
83
+ # target_sizes = torch.tensor([image.size[::-1]])
84
+ # results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
85
+
86
+ # # Draw bounding boxes and labels
87
+ # draw = ImageDraw.Draw(image)
88
+ # for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
89
+ # x_min, y_min, x_max, y_max = box
90
+ # draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
91
+ # label_name = model.config.id2label[label.item()]
92
+ # draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
93
+
94
+ # return image
95
+
96
+ # @app.post("/detect_accident")
97
+ # async def process_frame(file: UploadFile = File(...)):
98
+ # """API endpoint to process an uploaded frame."""
99
+ # try:
100
+ # # Read and preprocess image
101
+ # image = Image.open(io.BytesIO(await file.read()))
102
+ # image = image.convert("RGB") # Ensure compatibility with the model
103
+
104
+ # # Detect accidents
105
+ # processed_image = detect_accident(image)
106
+
107
+ # # Save the processed image into bytes to send back
108
+ # img_byte_arr = io.BytesIO()
109
+ # processed_image.save(img_byte_arr, format="JPEG")
110
+ # img_byte_arr.seek(0)
111
+
112
+ # # Return the image as a streaming response
113
+ # return StreamingResponse(img_byte_arr, media_type="image/jpeg")
114
+ # except Exception as e:
115
+ # return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
116
+
117
+ # # Run the app
118
+ # if __name__ == "__main__":
119
+ # import uvicorn
120
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
121
 
122
 
123