Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import logging
|
|
12 |
import matplotlib.pyplot as plt
|
13 |
import gradio as gr
|
14 |
from PIL import ImageDraw
|
|
|
15 |
|
16 |
# Load environment variables
|
17 |
# load_dotenv()
|
@@ -20,6 +21,11 @@ from PIL import ImageDraw
|
|
20 |
logging.basicConfig(level=logging.INFO)
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
class YOLODetector:
|
24 |
def __init__(self, api_url: str, stride: int = 32):
|
25 |
self.api_url = api_url
|
@@ -47,8 +53,8 @@ class YOLODetector:
|
|
47 |
logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
|
48 |
return new_size
|
49 |
|
50 |
-
def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45) -> Dict:
|
51 |
-
params = {'conf_thres': conf_thres, 'iou_thres': iou_thres}
|
52 |
|
53 |
try:
|
54 |
with open(image_path, 'rb') as image_file:
|
@@ -94,7 +100,7 @@ def plot_detections(image_path: str, detections: List[Dict]):
|
|
94 |
plt.tight_layout()
|
95 |
plt.show()
|
96 |
|
97 |
-
def process_image(image, conf_thres, iou_thres):
|
98 |
api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
|
99 |
detector = YOLODetector(api_url)
|
100 |
|
@@ -103,7 +109,7 @@ def process_image(image, conf_thres, iou_thres):
|
|
103 |
temp_file_path = temp_file.name
|
104 |
|
105 |
processed_image_path = detector.preprocess_image(temp_file_path)
|
106 |
-
results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres)
|
107 |
|
108 |
img = Image.open(processed_image_path)
|
109 |
img_draw = img.copy()
|
@@ -123,24 +129,27 @@ def process_image(image, conf_thres, iou_thres):
|
|
123 |
def main():
|
124 |
with gr.Blocks() as demo:
|
125 |
gr.Markdown("# YOLO Object Detection")
|
126 |
-
gr.Markdown("Upload an image and adjust the confidence and IOU thresholds to detect objects.")
|
127 |
|
128 |
with gr.Row():
|
129 |
with gr.Column():
|
130 |
image_input = gr.Image(type="pil", label="Input Image")
|
|
|
131 |
conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
|
132 |
iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
|
133 |
with gr.Column():
|
134 |
image_output = gr.Image(type="pil", label="Detection Result")
|
135 |
-
|
|
|
136 |
if image is None:
|
137 |
return None
|
138 |
-
return process_image(image, conf_thres, iou_thres)
|
139 |
|
140 |
-
inputs = [image_input, conf_slider, iou_slider]
|
141 |
outputs = image_output
|
142 |
|
143 |
image_input.change(fn=update_output, inputs=inputs, outputs=outputs)
|
|
|
144 |
conf_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
|
145 |
iou_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
|
146 |
|
|
|
12 |
import matplotlib.pyplot as plt
|
13 |
import gradio as gr
|
14 |
from PIL import ImageDraw
|
15 |
+
from enum import Enum
|
16 |
|
17 |
# Load environment variables
|
18 |
# load_dotenv()
|
|
|
21 |
logging.basicConfig(level=logging.INFO)
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
+
class EnumModel(str, Enum):
|
25 |
+
lion90 = "lion90"
|
26 |
+
dulcet72 = "dulcet72"
|
27 |
+
default = "default"
|
28 |
+
|
29 |
class YOLODetector:
|
30 |
def __init__(self, api_url: str, stride: int = 32):
|
31 |
self.api_url = api_url
|
|
|
53 |
logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
|
54 |
return new_size
|
55 |
|
56 |
+
def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45, model: EnumModel = EnumModel.dulcet72) -> Dict:
|
57 |
+
params = {'conf_thres': conf_thres, 'iou_thres': iou_thres, 'model': model.value}
|
58 |
|
59 |
try:
|
60 |
with open(image_path, 'rb') as image_file:
|
|
|
100 |
plt.tight_layout()
|
101 |
plt.show()
|
102 |
|
103 |
+
def process_image(image, conf_thres, iou_thres, model):
|
104 |
api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
|
105 |
detector = YOLODetector(api_url)
|
106 |
|
|
|
109 |
temp_file_path = temp_file.name
|
110 |
|
111 |
processed_image_path = detector.preprocess_image(temp_file_path)
|
112 |
+
results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres, model=model)
|
113 |
|
114 |
img = Image.open(processed_image_path)
|
115 |
img_draw = img.copy()
|
|
|
129 |
def main():
|
130 |
with gr.Blocks() as demo:
|
131 |
gr.Markdown("# YOLO Object Detection")
|
132 |
+
gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")
|
133 |
|
134 |
with gr.Row():
|
135 |
with gr.Column():
|
136 |
image_input = gr.Image(type="pil", label="Input Image")
|
137 |
+
model_dropdown = gr.Dropdown(choices=[model.value for model in EnumModel], value=EnumModel.dulcet72.value, label="Model")
|
138 |
conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
|
139 |
iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
|
140 |
with gr.Column():
|
141 |
image_output = gr.Image(type="pil", label="Detection Result")
|
142 |
+
|
143 |
+
def update_output(image, model, conf_thres, iou_thres):
|
144 |
if image is None:
|
145 |
return None
|
146 |
+
return process_image(image, conf_thres, iou_thres, EnumModel(model))
|
147 |
|
148 |
+
inputs = [image_input, model_dropdown, conf_slider, iou_slider]
|
149 |
outputs = image_output
|
150 |
|
151 |
image_input.change(fn=update_output, inputs=inputs, outputs=outputs)
|
152 |
+
model_dropdown.change(fn=update_output, inputs=inputs, outputs=outputs)
|
153 |
conf_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
|
154 |
iou_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
|
155 |
|