File size: 6,257 Bytes
921c700 25fa20a 921c700 25fa20a 921c700 632fb40 921c700 632fb40 921c700 25fa20a 921c700 25fa20a a22d803 921c700 25fa20a 921c700 25fa20a 921c700 25fa20a 921c700 25fa20a 921c700 25fa20a 921c700 25fa20a 921c700 25fa20a 921c700 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
import cv2
import numpy as np
import requests
import torch
import math
from PIL import Image
import io
from typing import List, Dict, Union, Tuple
from tempfile import NamedTemporaryFile
import logging
import matplotlib.pyplot as plt
import gradio as gr
from PIL import ImageDraw
from enum import Enum
# Loading environment variables from a .env file is currently disabled
# (presumably python-dotenv; no dotenv import is present in this file).
# load_dotenv()
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnumModel(str, Enum):
    """Model identifiers that can be sent to the detection API.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value and can be rebuilt from it with ``EnumModel(value)``.
    """

    # Member names intentionally mirror their wire values.
    lion90 = "lion90"
    dulcet72 = "dulcet72"
    default = "default"
class YOLODetector:
    """Client for a remote YOLO detection HTTP endpoint.

    Images are resized locally with OpenCV and then uploaded to ``api_url``
    as a multipart POST request.
    """

    def __init__(self, api_url: str, stride: int = 3):
        """Store the endpoint URL and the stride used to round image sizes.

        Args:
            api_url: URL the image is POSTed to by :meth:`detect`.
            stride: Divisor that target image sizes are rounded up to.

        NOTE(review): with the default stride of 3 the default 1280x1280
        target is rounded up to 1281x1281. YOLO models commonly use a
        stride of 32 -- confirm that 3 is intended.
        """
        self.api_url = api_url
        self.stride = stride

    def preprocess_image(self, image_path: str, target_size: Tuple[int, int] = (1280, 1280)) -> str:
        """Resize an image to a stride-aligned size and save it as a PNG.

        Args:
            image_path: Path of the image to read.
            target_size: Requested size, passed to ``cv2.resize`` as ``dsize``
                after rounding each value up to a multiple of ``self.stride``.
                A single int is treated as a square size.

        Returns:
            Path of a newly created temporary PNG file. The caller is
            responsible for deleting it.

        Raises:
            IOError: If the image cannot be read or the result cannot be
                written.
        """
        img = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise IOError(f"Could not read image: {image_path}")

        img_size = self.check_img_size(target_size, s=self.stride)
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        img_resized = cv2.resize(img, tuple(img_size))

        # A unique file per call avoids concurrent requests overwriting each
        # other's image, which a single fixed output path would allow.
        with NamedTemporaryFile(delete=False, suffix='.png') as resized_file:
            resized_image_path = resized_file.name
        if not cv2.imwrite(resized_image_path, img_resized):
            os.unlink(resized_image_path)
            raise IOError(f"Could not write resized image: {resized_image_path}")
        return resized_image_path

    @staticmethod
    def make_divisible(x: int, divisor: int) -> int:
        """Return the smallest multiple of ``divisor`` that is >= ``x``."""
        return math.ceil(x / divisor) * divisor

    def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32, floor: int = 0) -> Union[int, List[int]]:
        """Round an image size up to a multiple of the stride ``s``.

        Args:
            imgsz: A single size or a sequence of sizes.
            s: Stride every size must be a multiple of.
            floor: Lower bound applied to every rounded size.

        Returns:
            An int when ``imgsz`` is an int, otherwise a list of ints. A
            warning is logged when the size had to be changed.
        """
        stride = int(s)
        if isinstance(imgsz, int):
            new_size = max(self.make_divisible(imgsz, stride), floor)
        else:
            # Normalise tuples to a list so the comparison below is fair.
            imgsz = list(imgsz)
            new_size = [max(self.make_divisible(x, stride), floor) for x in imgsz]
        if new_size != imgsz:
            logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
        return new_size

    def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45, model: EnumModel = EnumModel.dulcet72) -> Dict:
        """Upload an image to the detection API and return the JSON response.

        Args:
            image_path: Path of the image file to upload.
            conf_thres: Sent as the ``conf_thres`` query parameter.
            iou_thres: Sent as the ``iou_thres`` query parameter.
            model: Model to use; an ``EnumModel`` member or its string value.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.exceptions.RequestException: On network errors, timeouts
                or a non-2xx status code.
            IOError: If the image file cannot be opened or read.
            ValueError: If ``model`` is not a valid ``EnumModel`` value.
        """
        params = {'conf_thres': conf_thres, 'iou_thres': iou_thres, 'model': EnumModel(model).value}
        try:
            # Stream the file directly; no intermediate copy is needed.
            with open(image_path, 'rb') as image_file:
                files = {'file': ('image.png', image_file, 'image/png')}
                response = requests.post(self.api_url, params=params, files=files, timeout=30)
            response.raise_for_status()
            results = response.json()
        # RequestException derives from IOError, so it must be handled first.
        except requests.exceptions.RequestException as e:
            logger.error("Error occurred during request: %s", e)
            raise
        except IOError as e:
            logger.error("Error processing file: %s", e)
            raise
        logger.info("Detection results: %s", results)
        return results
def plot_detections(image_path: str, detections: List[Dict]):
    """Draw detection boxes and labels on an image and show it with matplotlib.

    Args:
        image_path: Path of the image to display.
        detections: Dicts with the keys ``'bbox'``, ``'class'`` and
            ``'confidence'``. The four ``'bbox'`` values are used as two
            opposite corners ``(x1, y1, x2, y2)`` of the rectangle.
            NOTE(review): confirm the API returns corners, not x/y/width/height.

    Raises:
        IOError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError(f"Could not read image: {image_path}")
    # OpenCV loads BGR; matplotlib expects RGB.
    img_draw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).copy()

    box_color = (255, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), box_color, 2)

        label = f"{det['class']} {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, font, 0.5, 1)
        # Put the label above the box; if that would leave the canvas,
        # move it just inside the top edge of the box instead.
        label_bottom = y1 if y1 - text_height - 5 >= 0 else y1 + text_height + 5
        cv2.rectangle(img_draw, (x1, label_bottom - text_height - 5), (x1 + text_width, label_bottom), box_color, -1)
        cv2.putText(img_draw, label, (x1, label_bottom - 5), font, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    plt.figure(figsize=(12, 9))
    plt.imshow(img_draw)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
def process_image(image, conf_thres, iou_thres, model):
    """Run remote detection on a PIL image and return an annotated copy.

    The image is saved to a temporary PNG, resized by
    ``YOLODetector.preprocess_image``, sent to the API, and the returned
    boxes are drawn on the resized image.

    Args:
        image: PIL image to analyse.
        conf_thres: Confidence threshold forwarded to the API.
        iou_thres: IOU threshold forwarded to the API.
        model: ``EnumModel`` member selecting the remote model.

    Returns:
        A PIL image (the resized input) with boxes and labels drawn in red.

    Raises:
        Whatever ``YOLODetector.preprocess_image`` / ``detect`` raise; the
        temporary files are removed in every case.
    """
    api_url = os.getenv('YOLO_API_URL', 'https://o3f0aia4gtlf0l-8888.proxy.runpod.net/detect/')
    detector = YOLODetector(api_url)

    temp_file_path = None
    processed_image_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file_path = temp_file.name
            image.save(temp_file, format='PNG')

        processed_image_path = detector.preprocess_image(temp_file_path)
        results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres, model=model)

        # copy() forces the pixel data to load, so the result stays valid
        # after the file is closed and deleted.
        with Image.open(processed_image_path) as img:
            img_draw = img.copy()
    finally:
        # Always remove the temporary files, including on failure.
        for path in (temp_file_path, processed_image_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

    detections = results.get('detections')
    if detections is None:
        logger.warning("Detection response contains no 'detections' key: %s", results)
        detections = []

    draw = ImageDraw.Draw(img_draw)
    for det in detections:
        # The four bbox values are used as two opposite corners.
        # NOTE(review): confirm the API returns corners, not x/y/width/height.
        x1, y1, x2, y2 = map(int, det['bbox'])
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        label = f"{det['class']} {det['confidence']:.2f}"
        # Clamp so labels of boxes touching the top edge stay visible.
        draw.text((x1, max(y1 - 10, 0)), label, fill="red")
    return img_draw
def main():
    """Build the Gradio interface for the detector and launch it."""

    def update_output(image, model, conf_thres, iou_thres):
        """Re-run detection for the current inputs; no image means no output."""
        return None if image is None else process_image(image, conf_thres, iou_thres, EnumModel(model))

    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Object Detection")
        gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                model_dropdown = gr.Dropdown(
                    choices=[member.value for member in EnumModel],
                    value=EnumModel.dulcet72.value,
                    label="Model",
                )
                conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
            with gr.Column():
                image_output = gr.Image(type="pil", label="Detection Result")

        # Every control triggers the same refresh with the same arguments.
        controls = [image_input, model_dropdown, conf_slider, iou_slider]
        for control in controls:
            control.change(fn=update_output, inputs=controls, outputs=image_output)

    demo.launch()


if __name__ == "__main__":
    main()
|