# visionGod / app.py
# Robertooo's picture
# Update app.py
# a22d803 verified
import os
import cv2
import numpy as np
import requests
import torch
import math
from PIL import Image
import io
from typing import List, Dict, Union, Tuple
from tempfile import NamedTemporaryFile
import logging
import matplotlib.pyplot as plt
import gradio as gr
from PIL import ImageDraw
from enum import Enum
# Load environment variables
# (disabled: load_dotenv is not imported in this file, so the call stays commented out)
# load_dotenv()
# Configure logging: root logger at INFO level, plus a logger named after this
# module that the rest of the file uses for its messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnumModel(str, Enum):
    """Names of the detection models that can be requested.

    Members are also ``str`` instances, so each one compares equal to its
    own value. Declaration order is the iteration order.
    """

    lion90 = "lion90"
    dulcet72 = "dulcet72"
    default = "default"
class YOLODetector:
    """Client for a remote YOLO detection HTTP endpoint.

    Images are resized locally with OpenCV and then uploaded to ``api_url``
    as a multipart POST request.
    """

    def __init__(self, api_url: str, stride: int = 3):
        """Store the endpoint URL and the stride used to round image sizes.

        Args:
            api_url: URL that :meth:`detect` POSTs the image to.
            stride: Divisor that :meth:`preprocess_image` rounds each target
                dimension up to.
        """
        # NOTE(review): with the default stride of 3 the default 1280 target
        # becomes 1281, while check_img_size itself defaults to 32 -- confirm
        # which value is intended.
        self.api_url = api_url
        self.stride = stride

    def preprocess_image(self, image_path: str, target_size: Tuple[int, int] = (1280, 1280)) -> str:
        """Resize an image to a stride-aligned size and save it as a new PNG.

        Args:
            image_path: Path of the image to read.
            target_size: Requested size, passed to ``cv2.resize`` after each
                dimension is rounded up to a multiple of ``self.stride``.
                A single int is treated as a square size.

        Returns:
            Path of a newly created temporary PNG file. The caller owns the
            file and is responsible for deleting it.

        Raises:
            IOError: If the input image cannot be read or the resized image
                cannot be written.
        """
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError(f"Could not read image: {image_path}")

        img_size = self.check_img_size(target_size, s=self.stride)
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        img_resized = cv2.resize(img, tuple(img_size))

        # A unique file per call: a fixed shared path would let concurrent
        # requests overwrite (or delete) each other's resized image.
        with NamedTemporaryFile(delete=False, suffix='.png') as resized_file:
            resized_image_path = resized_file.name

        written = False
        try:
            written = cv2.imwrite(resized_image_path, img_resized)
            if not written:
                raise IOError(f"Could not write resized image: {resized_image_path}")
        finally:
            if not written:
                # Do not leave an empty temp file behind on failure.
                os.unlink(resized_image_path)
        return resized_image_path

    @staticmethod
    def make_divisible(x: int, divisor: int) -> int:
        """Return the smallest multiple of ``divisor`` that is >= ``x``."""
        return math.ceil(x / divisor) * divisor

    def check_img_size(self, imgsz: Union[int, List[int]], s: int = 32, floor: int = 0) -> Union[int, List[int]]:
        """Round an image size up to a multiple of ``s``.

        Args:
            imgsz: A single dimension, or a sequence of dimensions.
            s: Stride every dimension must be a multiple of.
            floor: Lower bound applied to every resulting dimension.

        Returns:
            An int when ``imgsz`` is an int, otherwise a list with one entry
            per input dimension. A warning is logged when the result differs
            from the requested size.
        """
        if isinstance(imgsz, int):
            new_size = max(self.make_divisible(imgsz, int(s)), floor)
        else:
            # Normalise to a list so the comparison below is list-to-list
            # even when a tuple was passed in.
            imgsz = list(imgsz)
            new_size = [max(self.make_divisible(x, int(s)), floor) for x in imgsz]
        if new_size != imgsz:
            logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
        return new_size

    def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45, model: EnumModel = EnumModel.dulcet72) -> Dict:
        """Upload an image to the detection endpoint and return its JSON reply.

        Args:
            image_path: Path of the image file to upload.
            conf_thres: Sent as the ``conf_thres`` query parameter.
            iou_thres: Sent as the ``iou_thres`` query parameter.
            model: Model to request; an ``EnumModel`` member or its string
                value.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts or non-2xx responses (logged, then re-raised).
            IOError: If the image file cannot be read (logged, then
                re-raised).
            ValueError: If ``model`` is not a valid ``EnumModel`` value.
        """
        params = {'conf_thres': conf_thres, 'iou_thres': iou_thres, 'model': EnumModel(model).value}
        try:
            # Stream the file straight into the request; no intermediate
            # copy is needed because the upload only reads the bytes.
            with open(image_path, 'rb') as image_file:
                files = {'file': ('image.png', image_file, 'image/png')}
                response = requests.post(self.api_url, params=params, files=files, timeout=30)
            response.raise_for_status()
            results = response.json()
        except requests.exceptions.RequestException as e:
            # Must come first: RequestException is itself an IOError subclass.
            logger.error(f"Error occurred during request: {str(e)}")
            raise
        except IOError as e:
            logger.error(f"Error processing file: {str(e)}")
            raise
        logger.info(f"Detection results: {results}")
        return results
def plot_detections(image_path: str, detections: List[Dict]):
    """Draw detection boxes and labels on an image and show it with matplotlib.

    Args:
        image_path: Path of the image to load.
        detections: Detections to draw. Each entry is read through the keys
            ``'bbox'`` (four numbers used as the two corner points of the
            rectangle), ``'class'`` and ``'confidence'``.

    Raises:
        IOError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError(f"Could not read image: {image_path}")

    # OpenCV loads BGR; matplotlib expects RGB. cvtColor returns a new array,
    # so drawing on it does not touch the loaded image.
    img_draw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    for det in detections:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img_draw, (x1, y1), (x2, y2), (255, 0, 0), 2)

        label = f"{det['class']} {det['confidence']:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Place the label strip above the box; when the box touches the top
        # of the image there is no room, so move the strip just inside it.
        label_bottom = y1
        if label_bottom - text_height - 5 < 0:
            label_bottom = y1 + text_height + 5

        cv2.rectangle(img_draw, (x1, label_bottom - text_height - 5), (x1 + text_width, label_bottom), (255, 0, 0), -1)
        cv2.putText(img_draw, label, (x1, label_bottom - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    plt.figure(figsize=(12, 9))
    plt.imshow(img_draw)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
def process_image(image, conf_thres, iou_thres, model):
    """Run remote detection on a PIL image and return an annotated copy.

    The image is saved to a temporary PNG, resized by
    ``YOLODetector.preprocess_image``, sent to the endpoint given by the
    ``YOLO_API_URL`` environment variable (with a built-in fallback URL), and
    the returned detections are drawn on the resized image.

    Args:
        image: Image object providing ``save(file, format=...)`` (a PIL image
            in the Gradio flow).
        conf_thres: Confidence threshold forwarded to the detector.
        iou_thres: IOU threshold forwarded to the detector.
        model: Model selector forwarded to the detector.

    Returns:
        A PIL image (the resized input) with a red rectangle and a
        ``"<class> <confidence>"`` label per detection.
    """
    api_url = os.getenv('YOLO_API_URL', 'https://o3f0aia4gtlf0l-8888.proxy.runpod.net/detect/')
    detector = YOLODetector(api_url)

    temp_file_path = None
    processed_image_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            # Record the name first so the file is cleaned up even if the
            # save below fails.
            temp_file_path = temp_file.name
            image.save(temp_file, format='PNG')

        processed_image_path = detector.preprocess_image(temp_file_path)
        results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres, model=model)

        # copy() loads the pixel data, so the file can be closed and deleted
        # while the returned image stays usable.
        with Image.open(processed_image_path) as img:
            img_draw = img.copy()

        draw = ImageDraw.Draw(img_draw)
        for det in results['detections']:
            x1, y1, x2, y2 = map(int, det['bbox'])
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            label = f"{det['class']} {det['confidence']:.2f}"
            # Keep the label on the canvas for boxes touching the top edge.
            draw.text((x1, max(y1 - 10, 0)), label, fill="red")
        return img_draw
    finally:
        # Remove both temporary files whether or not detection succeeded.
        for path in (temp_file_path, processed_image_path):
            if path is not None and os.path.exists(path):
                os.unlink(path)
def main():
    """Build the Gradio detection UI and launch it.

    The page has an image input, a model dropdown and two threshold sliders
    on one side and the annotated result on the other. A change to any of the
    four inputs re-runs detection.
    """

    def update_output(image, model, conf_thres, iou_thres):
        """Return the annotated image, or None while no image is loaded."""
        if image is not None:
            return process_image(image, conf_thres, iou_thres, EnumModel(model))
        return None

    model_names = [member.value for member in EnumModel]

    with gr.Blocks() as demo:
        gr.Markdown("# YOLO Object Detection")
        gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                model_dropdown = gr.Dropdown(choices=model_names, value=EnumModel.dulcet72.value, label="Model")
                conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
            with gr.Column():
                image_output = gr.Image(type="pil", label="Detection Result")

        # Every input triggers the same callback with the same argument list;
        # the list order matches update_output's parameter order.
        controls = [image_input, model_dropdown, conf_slider, iou_slider]
        for control in controls:
            control.change(fn=update_output, inputs=controls, outputs=image_output)

    demo.launch()
# Start the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()