import contextlib
import io
import logging
import math
import os
from enum import Enum
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple, Union

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image, ImageDraw

# NOTE(review): io, numpy and torch are imported but not used in this file.

# Load environment variables
# load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EnumModel(str, Enum):
    """Model names sent to the detection API as the ``model`` query parameter."""

    lion90 = "lion90"
    dulcet72 = "dulcet72"
    default = "default"


class YOLODetector:
    """Client for a remote YOLO detection HTTP endpoint.

    Args:
        api_url: URL that images are POSTed to.
        stride: Image dimensions are rounded up to a multiple of this value
            in :meth:`preprocess_image`.
    """

    def __init__(self, api_url: str, stride: int = 3):
        self.api_url = api_url
        self.stride = stride

    def preprocess_image(self, image_path: str,
                         target_size: Union[int, Tuple[int, int]] = (1280, 1280)) -> str:
        """Resize the image at *image_path* and write it to a new temporary PNG.

        Args:
            image_path: Path of an image readable by ``cv2.imread``.
            target_size: Requested size, passed to ``cv2.resize`` as ``dsize``,
                i.e. ``(width, height)``. An int means a square image. Each
                dimension is rounded up to a multiple of ``self.stride``.

        Returns:
            Path of a uniquely named temporary PNG. The caller must delete it.

        Raises:
            IOError: If the input cannot be read or the output cannot be written.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise IOError(f"Could not read image: {image_path}")
        img_size = self.check_img_size(target_size, s=self.stride)
        # cv2.resize needs a 2-sequence; check_img_size returns an int for int input.
        dsize = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)
        img_resized = cv2.resize(img, dsize)
        # A unique file per call: a fixed path would be clobbered by concurrent requests.
        with NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            resized_image_path = tmp.name
        if not cv2.imwrite(resized_image_path, img_resized):
            with contextlib.suppress(FileNotFoundError):
                os.unlink(resized_image_path)
            raise IOError(f"Could not write resized image: {resized_image_path}")
        return resized_image_path

    @staticmethod
    def make_divisible(x: int, divisor: int) -> int:
        """Round *x* up to the nearest multiple of *divisor*."""
        return math.ceil(x / divisor) * divisor

    def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32,
                       floor: int = 0) -> Union[int, List[int]]:
        """Round each dimension of *imgsz* up to a multiple of *s*, never below *floor*.

        Returns an int for int input and a list otherwise. Logs a warning when
        the size had to change.
        """
        if isinstance(imgsz, int):
            new_size = max(self.make_divisible(imgsz, int(s)), floor)
        else:
            imgsz = list(imgsz)
            new_size = [max(self.make_divisible(x, int(s)), floor) for x in imgsz]
        if new_size != imgsz:
            logger.warning('WARNING: Image size %s must be multiple of max stride %s, updating to %s',
                           imgsz, s, new_size)
        return new_size

    def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45,
               model: Union[EnumModel, str] = EnumModel.dulcet72) -> Dict:
        """POST the image at *image_path* to ``self.api_url`` and return the parsed JSON.

        Args:
            image_path: Image file uploaded as multipart field ``file``.
            conf_thres: Confidence threshold, sent as a query parameter.
            iou_thres: IoU threshold, sent as a query parameter.
            model: An ``EnumModel`` member or its string value.

        Raises:
            requests.exceptions.RequestException: On network/HTTP/JSON errors.
            IOError: If the image file cannot be read.
        """
        params = {'conf_thres': conf_thres, 'iou_thres': iou_thres,
                  'model': EnumModel(model).value}
        try:
            # Upload the file directly; no intermediate copy is needed.
            with open(image_path, 'rb') as image_file:
                files = {'file': ('image.png', image_file, 'image/png')}
                response = requests.post(self.api_url, params=params, files=files, timeout=30)
            response.raise_for_status()
            results = response.json()
        # RequestException subclasses IOError, so it must be caught first.
        except requests.exceptions.RequestException as e:
            logger.error("Error occurred during request: %s", e)
            raise
        except IOError as e:
            logger.error("Error processing file: %s", e)
            raise
        logger.info("Detection results: %s", results)
        return results


def plot_detections(image_path: str, detections: List[Dict]):
    """Show *image_path* with *detections* drawn on it in a matplotlib window.

    Each detection needs ``'bbox'``, ``'class'`` and ``'confidence'`` keys.
    The bbox is drawn as two corners ``(x1, y1)``-``(x2, y2)``; presumably the
    API returns xyxy boxes (TODO confirm).
    """
    img = cv2.imread(image_path)
    if img is None:
        raise IOError(f"Could not read image: {image_path}")
    # cvtColor returns a new array, so drawing does not touch `img`.
    img_draw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), (255, 0, 0), 2)
        label = f"{det['class']} {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Filled background strip above the box so the label stays legible.
        cv2.rectangle(img_draw, (x1, y1 - text_height - 5), (x1 + text_width, y1), (255, 0, 0), -1)
        cv2.putText(img_draw, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    (255, 255, 255), 1, cv2.LINE_AA)
    plt.figure(figsize=(12, 9))
    plt.imshow(img_draw)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def process_image(image: Image.Image, conf_thres: float, iou_thres: float,
                  model: Union[EnumModel, str]) -> Image.Image:
    """Run remote detection on a PIL *image* and return a copy with boxes drawn.

    The API URL comes from the ``YOLO_API_URL`` environment variable, with a
    hard-coded fallback. The returned image is the resized image that was sent
    to the API, so box coordinates line up with it. Temporary files are removed
    even if detection fails.
    """
    api_url = os.getenv('YOLO_API_URL', 'https://o3f0aia4gtlf0l-8888.proxy.runpod.net/detect/')
    detector = YOLODetector(api_url)
    temp_file_path = None
    processed_image_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file_path = temp_file.name
            image.save(temp_file, format='PNG')
        processed_image_path = detector.preprocess_image(temp_file_path)
        results = detector.detect(processed_image_path, conf_thres=conf_thres,
                                  iou_thres=iou_thres, model=model)
        # Copy into memory and close the file handle so the file can be deleted.
        with Image.open(processed_image_path) as img:
            img_draw = img.copy()
    finally:
        for path in (temp_file_path, processed_image_path):
            if path:
                with contextlib.suppress(FileNotFoundError):
                    os.unlink(path)

    draw = ImageDraw.Draw(img_draw)
    # Boxes are drawn as corners (x1, y1)-(x2, y2); presumably xyxy from the API (TODO confirm).
    for det in results.get('detections', []):
        x1, y1, x2, y2 = map(int, det['bbox'])
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        label = f"{det['class']} {det['confidence']:.2f}"
        draw.text((x1, y1 - 10), label, fill="red")
    return img_draw


def main():
    """Build and launch the Gradio UI; every input change re-runs detection."""
    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Object Detection")
        gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                model_dropdown = gr.Dropdown(choices=[m.value for m in EnumModel],
                                             value=EnumModel.dulcet72.value, label="Model")
                conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
            with gr.Column():
                image_output = gr.Image(type="pil", label="Detection Result")

        def update_output(image, model, conf_thres, iou_thres):
            # Nothing to detect until an image has been uploaded.
            if image is None:
                return None
            return process_image(image, conf_thres, iou_thres, EnumModel(model))

        inputs = [image_input, model_dropdown, conf_slider, iou_slider]
        outputs = image_output
        # Re-run detection whenever any input changes.
        for component in inputs:
            component.change(fn=update_output, inputs=inputs, outputs=outputs)

    demo.launch()


if __name__ == "__main__":
    main()