Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import requests
|
5 |
+
import torch
|
6 |
+
import math
|
7 |
+
from PIL import Image
|
8 |
+
import io
|
9 |
+
from typing import List, Dict, Union, Tuple
|
10 |
+
from tempfile import NamedTemporaryFile
|
11 |
+
import logging
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import gradio as gr
|
14 |
+
from PIL import ImageDraw
|
15 |
+
|
16 |
+
# Environment variables (e.g. YOLO_API_URL) are read directly with os.getenv.
# Loading them from a .env file is disabled; NOTE: load_dotenv is not imported
# in this file, so the call below cannot simply be uncommented.
# load_dotenv()

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
class YOLODetector:
    """Client for a remote YOLO detection HTTP endpoint.

    Resizes an image to stride-aligned dimensions and uploads it to
    ``api_url`` together with the confidence and IoU thresholds.
    """

    def __init__(self, api_url: str, stride: int = 32):
        """Store the endpoint URL and the stride used to align image sizes."""
        self.api_url = api_url
        self.stride = stride

    def preprocess_image(self, image_path: str, target_size: Tuple[int, int] = (640, 640)) -> str:
        """Resize the image at ``image_path`` and save the result as a PNG.

        Args:
            image_path: Path of the image to read.
            target_size: Requested output size; each dimension is rounded up
                to a multiple of ``self.stride``. It is handed to
                ``cv2.resize`` as-is, i.e. as ``(width, height)``.

        Returns:
            Path of a newly created temporary PNG file. The caller owns the
            file and is responsible for deleting it.

        Raises:
            IOError: If the image cannot be read or the result cannot be
                written.
        """
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise IOError(f"Could not read image: {image_path}")

        img_size = self.check_img_size(target_size, s=self.stride)
        if isinstance(img_size, int):
            # A single int means a square target.
            img_size = (img_size, img_size)
        img_resized = cv2.resize(img, tuple(img_size))

        # Use a unique file per call: a fixed shared path would let concurrent
        # requests overwrite (or delete) each other's resized image.
        with NamedTemporaryFile(delete=False, suffix='.png') as resized_file:
            resized_image_path = resized_file.name

        if not cv2.imwrite(resized_image_path, img_resized):
            os.unlink(resized_image_path)
            raise IOError(f"Could not write resized image: {resized_image_path}")
        return resized_image_path

    @staticmethod
    def make_divisible(x: int, divisor: int) -> int:
        """Return the smallest multiple of ``divisor`` that is >= ``x``."""
        return math.ceil(x / divisor) * divisor

    def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32, floor: int = 0) -> Union[int, List[int]]:
        """Round an image size up to a multiple of stride ``s``.

        Args:
            imgsz: A single size or a sequence of sizes.
            s: Stride every dimension must be a multiple of.
            floor: Lower bound applied to every resulting dimension.

        Returns:
            An int when ``imgsz`` is an int, otherwise a list with one
            adjusted value per input dimension. A warning is logged when the
            size had to be changed.
        """
        if isinstance(imgsz, int):
            new_size = max(self.make_divisible(imgsz, int(s)), floor)
        else:
            # Normalise tuples to a list so the comparison below is meaningful.
            imgsz = list(imgsz)
            new_size = [max(self.make_divisible(x, int(s)), floor) for x in imgsz]
        if new_size != imgsz:
            logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
        return new_size

    def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45) -> Dict:
        """Upload an image to the detection endpoint and return its JSON reply.

        Args:
            image_path: Path of the image file to upload.
            conf_thres: Confidence threshold sent as a query parameter.
            iou_thres: IoU threshold sent as a query parameter.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts or a non-success HTTP status.
            IOError: If ``image_path`` cannot be opened or read.
        """
        params = {'conf_thres': conf_thres, 'iou_thres': iou_thres}

        try:
            # Stream the file directly; no intermediate copy is needed.
            with open(image_path, 'rb') as image_file:
                files = {'file': ('image.png', image_file, 'image/png')}
                response = requests.post(self.api_url, params=params, files=files, timeout=30)

            response.raise_for_status()
            results = response.json()
        # RequestException derives from IOError, so it must be handled first.
        except requests.exceptions.RequestException as e:
            logger.error("Error occurred during request: %s", e)
            raise
        except IOError as e:
            logger.error("Error processing file: %s", e)
            raise

        logger.info("Detection results: %s", results)
        return results
|
76 |
+
|
77 |
+
def plot_detections(image_path: str, detections: List[Dict]):
    """Show an image with detection boxes and labels in a matplotlib window.

    Args:
        image_path: Path of the image to display.
        detections: Dicts with the keys ``'bbox'``, ``'class'`` and
            ``'confidence'``. The four ``bbox`` values are used as two corner
            points ``(x1, y1, x2, y2)`` — assumes that is what the detection
            API returns; TODO confirm.

    Raises:
        IOError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError(f"Could not read image: {image_path}")

    # OpenCV loads images as BGR; matplotlib expects RGB.
    img_draw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    box_color = (255, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), box_color, 2)

        label = f"{det['class']} {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, font, 0.5, 1)
        # Keep the label inside the image when the box touches the top edge.
        label_bottom = max(y1, text_height + 5)
        cv2.rectangle(
            img_draw,
            (x1, label_bottom - text_height - 5),
            (x1 + text_width, label_bottom),
            box_color,
            -1,
        )
        cv2.putText(img_draw, label, (x1, label_bottom - 5), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    plt.figure(figsize=(12, 9))
    plt.imshow(img_draw)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
|
96 |
+
|
97 |
+
def process_image(image, conf_thres, iou_thres):
    """Run remote detection on a PIL image and return an annotated copy.

    The image is saved to a temporary PNG, resized by the detector, uploaded
    to the endpoint named by the ``YOLO_API_URL`` environment variable (with
    a built-in default), and the returned boxes are drawn on the resized
    image.

    Args:
        image: Image object providing ``save(file, format=...)`` (a PIL image
            in the Gradio UI).
        conf_thres: Confidence threshold forwarded to the detector.
        iou_thres: IoU threshold forwarded to the detector.

    Returns:
        A PIL image (the resized input) with boxes and labels drawn in red.

    Raises:
        KeyError: If the response has no ``'detections'`` entry.
        Exception: Errors from preprocessing or the HTTP request propagate;
            temporary files are removed in every case.
    """
    api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
    detector = YOLODetector(api_url)

    temp_file_path = None
    processed_image_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file_path = temp_file.name
            image.save(temp_file, format='PNG')

        processed_image_path = detector.preprocess_image(temp_file_path)
        results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres)

        # Image.open is lazy; copy() loads the pixels so the file can be
        # closed and deleted safely afterwards.
        with Image.open(processed_image_path) as img:
            img_draw = img.copy()
    finally:
        # Always remove the temporary files, including on failure.
        for path in (temp_file_path, processed_image_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError as exc:
                logger.warning("Could not remove temporary file %s: %s", path, exc)

    draw = ImageDraw.Draw(img_draw)
    for det in results['detections']:
        # The four bbox values are used as corner points (x1, y1, x2, y2) —
        # assumes that is what the API returns; TODO confirm.
        x1, y1, x2, y2 = map(int, det['bbox'])
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        label = f"{det['class']} {det['confidence']:.2f}"
        # Clamp so labels of boxes at the top edge stay visible.
        draw.text((x1, max(y1 - 10, 0)), label, fill="red")

    return img_draw
|
122 |
+
|
123 |
+
def main():
    """Build the Gradio detection UI and start serving it.

    The page offers an image upload and two threshold sliders; changing any
    of them re-runs detection and refreshes the output image.
    """

    def update_output(image, conf_thres, iou_thres):
        # Nothing to detect until an image has been provided.
        return None if image is None else process_image(image, conf_thres, iou_thres)

    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Object Detection")
        gr.Markdown("Upload an image and adjust the confidence and IOU thresholds to detect objects.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
            with gr.Column():
                image_output = gr.Image(type="pil", label="Detection Result")

        # Every control triggers the same callback with the same inputs.
        controls = [image_input, conf_slider, iou_slider]
        for control in controls:
            control.change(fn=update_output, inputs=controls, outputs=image_output)

    demo.launch()
|
148 |
+
|
149 |
+
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()
|