Robertooo commited on
Commit
921c700
·
verified ·
1 Parent(s): ed9ba85

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import requests
5
+ import torch
6
+ import math
7
+ from PIL import Image
8
+ import io
9
+ from typing import List, Dict, Union, Tuple
10
+ from tempfile import NamedTemporaryFile
11
+ import logging
12
+ import matplotlib.pyplot as plt
13
+ import gradio as gr
14
+ from PIL import ImageDraw
15
+
16
+ # Load environment variables
17
+ # load_dotenv()
18
+
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ class YOLODetector:
24
+ def __init__(self, api_url: str, stride: int = 32):
25
+ self.api_url = api_url
26
+ self.stride = stride
27
+
28
+ def preprocess_image(self, image_path: str, target_size: Tuple[int, int] = (640, 640)) -> str:
29
+ img = cv2.imread(image_path)
30
+ img_size = self.check_img_size(target_size, s=self.stride)
31
+ img_resized = cv2.resize(img, img_size)
32
+ resized_image_path = '/tmp/resized_image.png'
33
+ cv2.imwrite(resized_image_path, img_resized)
34
+ return resized_image_path
35
+
36
+ @staticmethod
37
+ def make_divisible(x: int, divisor: int) -> int:
38
+ return math.ceil(x / divisor) * divisor
39
+
40
+ def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32, floor: int = 0) -> Union[int, List[int]]:
41
+ if isinstance(imgsz, int):
42
+ new_size = max(self.make_divisible(imgsz, int(s)), floor)
43
+ else:
44
+ imgsz = list(imgsz)
45
+ new_size = [max(self.make_divisible(x, int(s)), floor) for x in imgsz]
46
+ if new_size != imgsz:
47
+ logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
48
+ return new_size
49
+
50
+ def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45) -> Dict:
51
+ params = {'conf_thres': conf_thres, 'iou_thres': iou_thres}
52
+
53
+ try:
54
+ with open(image_path, 'rb') as image_file:
55
+ with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
56
+ temp_file.write(image_file.read())
57
+ temp_file_path = temp_file.name
58
+
59
+ with open(temp_file_path, 'rb') as f:
60
+ files = {'file': ('image.png', f, 'image/png')}
61
+ response = requests.post(self.api_url, params=params, files=files, timeout=30)
62
+
63
+ response.raise_for_status()
64
+ results = response.json()
65
+ logger.info(f"Detection results: {results}")
66
+ return results
67
+ except requests.exceptions.RequestException as e:
68
+ logger.error(f"Error occurred during request: {str(e)}")
69
+ raise
70
+ except IOError as e:
71
+ logger.error(f"Error processing file: {str(e)}")
72
+ raise
73
+ finally:
74
+ if 'temp_file_path' in locals():
75
+ os.unlink(temp_file_path)
76
+
77
+ def plot_detections(image_path: str, detections: List[Dict]):
78
+ img = cv2.imread(image_path)
79
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
80
+ img_draw = img.copy()
81
+
82
+ for det in detections:
83
+ x, y, w, h = map(int, det['bbox'])
84
+ cv2.rectangle(img_draw, (x, y), (w, h), (255, 0, 0), 2)
85
+
86
+ label = f"{det['class']} {det['confidence']:.2f}"
87
+ (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
88
+ cv2.rectangle(img_draw, (x, y - text_height - 5), (x + text_width, y), (255, 0, 0), -1)
89
+ cv2.putText(img_draw, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
90
+
91
+ plt.figure(figsize=(12, 9))
92
+ plt.imshow(img_draw)
93
+ plt.axis('off')
94
+ plt.tight_layout()
95
+ plt.show()
96
+
97
+ def process_image(image, conf_thres, iou_thres):
98
+ api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
99
+ detector = YOLODetector(api_url)
100
+
101
+ with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
102
+ image.save(temp_file, format='PNG')
103
+ temp_file_path = temp_file.name
104
+
105
+ processed_image_path = detector.preprocess_image(temp_file_path)
106
+ results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres)
107
+
108
+ img = Image.open(processed_image_path)
109
+ img_draw = img.copy()
110
+ draw = ImageDraw.Draw(img_draw)
111
+
112
+ for det in results['detections']:
113
+ x, y, w, h = map(int, det['bbox'])
114
+ draw.rectangle([x, y, w, h], outline="red", width=2)
115
+ label = f"{det['class']} {det['confidence']:.2f}"
116
+ draw.text((x, y - 10), label, fill="red")
117
+
118
+ os.unlink(temp_file_path)
119
+ os.unlink(processed_image_path)
120
+
121
+ return img_draw
122
+
123
+ def main():
124
+ with gr.Blocks() as demo:
125
+ gr.Markdown("# YOLO Object Detection")
126
+ gr.Markdown("Upload an image and adjust the confidence and IOU thresholds to detect objects.")
127
+
128
+ with gr.Row():
129
+ with gr.Column():
130
+ image_input = gr.Image(type="pil", label="Input Image")
131
+ conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
132
+ iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
133
+ with gr.Column():
134
+ image_output = gr.Image(type="pil", label="Detection Result")
135
+ def update_output(image, conf_thres, iou_thres):
136
+ if image is None:
137
+ return None
138
+ return process_image(image, conf_thres, iou_thres)
139
+
140
+ inputs = [image_input, conf_slider, iou_slider]
141
+ outputs = image_output
142
+
143
+ image_input.change(fn=update_output, inputs=inputs, outputs=outputs)
144
+ conf_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
145
+ iou_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
146
+
147
+ demo.launch()
148
+
149
+ if __name__ == "__main__":
150
+ main()