|
import io
import logging
import math
import os
from contextlib import suppress
from enum import Enum
from tempfile import NamedTemporaryFile
from typing import List, Dict, Union, Tuple

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from PIL import ImageDraw
|
|
|
|
|
|
|
|
|
|
|
# Set the root logger to INFO at import time; basicConfig is a no-op when the
# root logger already has handlers installed.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by every function and class in this file.
logger = logging.getLogger(__name__)
|
|
|
class EnumModel(str, Enum):
    """Names of the detection models a request can select.

    The ``str`` mix-in makes each member compare equal to its plain string
    value; ``member.value`` is the text sent as the ``model`` request
    parameter and shown as a dropdown choice.
    """

    lion90 = "lion90"
    dulcet72 = "dulcet72"
    default = "default"
|
|
|
class YOLODetector:
    """Client for a remote YOLO detection HTTP endpoint.

    Resizes an image on disk to a stride-aligned size and uploads it to
    ``api_url``, returning the decoded JSON response.
    """

    def __init__(self, api_url: str, stride: int = 3):
        """Store the endpoint URL and the stride used to align image sizes.

        Args:
            api_url: URL that image uploads are POSTed to.
            stride: Each target dimension is rounded up to a multiple of this.
                NOTE(review): a default of 3 is unusual (``check_img_size``
                itself defaults to 32) -- confirm this value is intended.
        """
        self.api_url = api_url
        self.stride = stride

    def preprocess_image(self, image_path: str, target_size: Tuple[int, int] = (1280, 1280)) -> str:
        """Resize the image at ``image_path`` and write it to a new PNG file.

        Args:
            image_path: Path of the image to read.
            target_size: Requested size passed to ``cv2.resize`` as ``dsize``
                after each dimension is rounded up to a multiple of
                ``self.stride``. A single int is treated as a square size.

        Returns:
            Path of a newly created temporary PNG file holding the resized
            image. The caller is responsible for deleting it.

        Raises:
            IOError: If the image cannot be read or the result cannot be written.
        """
        img = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise IOError(f"Could not read image: {image_path}")

        img_size = self.check_img_size(target_size, s=self.stride)
        # check_img_size returns an int or a list; give cv2.resize a 2-tuple.
        if isinstance(img_size, int):
            dsize = (img_size, img_size)
        else:
            dsize = tuple(img_size)
        img_resized = cv2.resize(img, dsize)

        # Use a unique file per call so concurrent requests cannot overwrite
        # (or delete) each other's resized image.
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            resized_image_path = temp_file.name

        if not cv2.imwrite(resized_image_path, img_resized):
            os.unlink(resized_image_path)
            raise IOError(f"Could not write resized image: {resized_image_path}")
        return resized_image_path

    @staticmethod
    def make_divisible(x: int, divisor: int) -> int:
        """Return the smallest multiple of ``divisor`` that is >= ``x``.

        Raises:
            ValueError: If ``divisor`` is not positive.
        """
        if divisor <= 0:
            raise ValueError(f"divisor must be positive, got {divisor}")
        return math.ceil(x / divisor) * divisor

    def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32, floor: int = 0) -> Union[int, List[int]]:
        """Round an image size up to a multiple of the stride ``s``.

        Args:
            imgsz: A single size or a sequence of sizes.
            s: Stride every size must be a multiple of (converted with ``int``).
            floor: Lower bound applied to every rounded size.

        Returns:
            An int when ``imgsz`` is an int, otherwise a list with one adjusted
            value per input dimension. A warning is logged when any value
            had to change.
        """
        stride = int(s)
        if isinstance(imgsz, int):
            new_size = max(self.make_divisible(imgsz, stride), floor)
        else:
            # Normalise to a list so the comparison below also works for tuples.
            imgsz = list(imgsz)
            new_size = [max(self.make_divisible(x, stride), floor) for x in imgsz]

        if new_size != imgsz:
            logger.warning(
                'WARNING: Image size %s must be multiple of max stride %s, updating to %s',
                imgsz, s, new_size,
            )
        return new_size

    def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45, model: EnumModel = EnumModel.dulcet72) -> Dict:
        """Upload an image to the detection endpoint and return its JSON reply.

        Args:
            image_path: Path of the image file to upload.
            conf_thres: Sent as the ``conf_thres`` query parameter.
            iou_thres: Sent as the ``iou_thres`` query parameter.
            model: Model to use; an ``EnumModel`` member or its string value.

        Returns:
            The decoded JSON body of the response.

        Raises:
            ValueError: If ``model`` is not a valid ``EnumModel`` value.
            requests.exceptions.RequestException: On network errors, timeouts
                or non-2xx responses (logged, then re-raised).
            IOError: If the image file cannot be opened or read (logged, then
                re-raised).
        """
        params = {
            'conf_thres': conf_thres,
            'iou_thres': iou_thres,
            # Accept both enum members and plain strings.
            'model': EnumModel(model).value,
        }

        try:
            # Stream the file straight from disk; no intermediate copy needed.
            with open(image_path, 'rb') as image_file:
                files = {'file': ('image.png', image_file, 'image/png')}
                response = requests.post(self.api_url, params=params, files=files, timeout=30)

            response.raise_for_status()
            results = response.json()
        # RequestException derives from IOError, so it must be handled first.
        except requests.exceptions.RequestException as e:
            logger.error("Error occurred during request: %s", e)
            raise
        except IOError as e:
            logger.error("Error processing file: %s", e)
            raise

        logger.info("Detection results: %s", results)
        return results
|
|
|
def plot_detections(image_path: str, detections: List[Dict]):
    """Draw detection boxes and labels on an image and show it with matplotlib.

    Args:
        image_path: Path of the image to load.
        detections: Dicts with the keys ``'bbox'`` (four numbers used as the
            two corner points of the rectangle -- presumably x1, y1, x2, y2 in
            pixels; confirm against the API), ``'class'`` and ``'confidence'``.

    Raises:
        IOError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError(f"Could not read image: {image_path}")

    # cvtColor returns a new array, so it is safe to draw on directly.
    img_draw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), (255, 0, 0), 2)

        label = f"{det['class']} {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Keep the label strip inside the image when the box touches the top edge.
        label_bottom = max(y1, text_height + 5)
        cv2.rectangle(
            img_draw,
            (x1, label_bottom - text_height - 5),
            (x1 + text_width, label_bottom),
            (255, 0, 0),
            -1,
        )
        cv2.putText(
            img_draw, label, (x1, label_bottom - 5),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA,
        )

    plt.figure(figsize=(12, 9))
    plt.imshow(img_draw)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
|
|
|
def process_image(image, conf_thres, iou_thres, model):
    """Run remote detection on an image and return an annotated copy.

    The endpoint is read from the ``YOLO_API_URL`` environment variable, with
    a built-in default URL as fallback.

    Args:
        image: Object with a PIL-style ``save(file, format=...)`` method (the
            UI passes a PIL image).
        conf_thres: Confidence threshold forwarded to the detector.
        iou_thres: IOU threshold forwarded to the detector.
        model: Model selector forwarded to the detector.

    Returns:
        A PIL image of the resized input with one red rectangle and text
        label per entry in ``results['detections']``.
    """
    api_url = os.getenv('YOLO_API_URL', 'https://o3f0aia4gtlf0l-8888.proxy.runpod.net/detect/')
    detector = YOLODetector(api_url)

    temp_file_path = None
    processed_image_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            # Record the name first so the file is cleaned up even if save() fails.
            temp_file_path = temp_file.name
            image.save(temp_file, format='PNG')

        processed_image_path = detector.preprocess_image(temp_file_path)
        results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres, model=model)

        # copy() loads the pixel data, so the file can be closed and removed.
        with Image.open(processed_image_path) as img:
            img_draw = img.copy()
        draw = ImageDraw.Draw(img_draw)

        for det in results['detections']:
            # bbox values are used as the two corner points of the rectangle.
            x1, y1, x2, y2 = map(int, det['bbox'])
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            label = f"{det['class']} {det['confidence']:.2f}"
            # Clamp so labels of boxes at the top edge stay on the canvas.
            draw.text((x1, max(y1 - 10, 0)), label, fill="red")

        return img_draw
    finally:
        # Best-effort removal of both temp files, on success and on failure.
        for path in (temp_file_path, processed_image_path):
            if path is not None:
                with suppress(OSError):
                    os.unlink(path)
|
|
|
def main():
    """Build the Gradio detection demo and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Object Detection")
        gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                model_dropdown = gr.Dropdown(
                    choices=[model.value for model in EnumModel],
                    value=EnumModel.dulcet72.value,
                    label="Model",
                )
                conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
            with gr.Column():
                image_output = gr.Image(type="pil", label="Detection Result")

        def update_output(image, model, conf_thres, iou_thres):
            """Re-run detection; return None while no image is loaded."""
            if image is None:
                return None
            return process_image(image, conf_thres, iou_thres, EnumModel(model))

        inputs = [image_input, model_dropdown, conf_slider, iou_slider]

        # Any change to an input component refreshes the detection result.
        for component in inputs:
            component.change(fn=update_output, inputs=inputs, outputs=image_output)

    demo.launch()
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|