Robertooo commited on
Commit
25fa20a
·
verified ·
1 Parent(s): 55e8596

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -12,6 +12,7 @@ import logging
12
  import matplotlib.pyplot as plt
13
  import gradio as gr
14
  from PIL import ImageDraw
 
15
 
16
  # Load environment variables
17
  # load_dotenv()
@@ -20,6 +21,11 @@ from PIL import ImageDraw
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
 
 
 
 
23
  class YOLODetector:
24
  def __init__(self, api_url: str, stride: int = 32):
25
  self.api_url = api_url
@@ -47,8 +53,8 @@ class YOLODetector:
47
  logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
48
  return new_size
49
 
50
- def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45) -> Dict:
51
- params = {'conf_thres': conf_thres, 'iou_thres': iou_thres}
52
 
53
  try:
54
  with open(image_path, 'rb') as image_file:
@@ -94,7 +100,7 @@ def plot_detections(image_path: str, detections: List[Dict]):
94
  plt.tight_layout()
95
  plt.show()
96
 
97
- def process_image(image, conf_thres, iou_thres):
98
  api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
99
  detector = YOLODetector(api_url)
100
 
@@ -103,7 +109,7 @@ def process_image(image, conf_thres, iou_thres):
103
  temp_file_path = temp_file.name
104
 
105
  processed_image_path = detector.preprocess_image(temp_file_path)
106
- results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres)
107
 
108
  img = Image.open(processed_image_path)
109
  img_draw = img.copy()
@@ -123,24 +129,27 @@ def process_image(image, conf_thres, iou_thres):
123
  def main():
124
  with gr.Blocks() as demo:
125
  gr.Markdown("# YOLO Object Detection")
126
- gr.Markdown("Upload an image and adjust the confidence and IOU thresholds to detect objects.")
127
 
128
  with gr.Row():
129
  with gr.Column():
130
  image_input = gr.Image(type="pil", label="Input Image")
 
131
  conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
132
  iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
133
  with gr.Column():
134
  image_output = gr.Image(type="pil", label="Detection Result")
135
- def update_output(image, conf_thres, iou_thres):
 
136
  if image is None:
137
  return None
138
- return process_image(image, conf_thres, iou_thres)
139
 
140
- inputs = [image_input, conf_slider, iou_slider]
141
  outputs = image_output
142
 
143
  image_input.change(fn=update_output, inputs=inputs, outputs=outputs)
 
144
  conf_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
145
  iou_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
146
 
 
12
  import matplotlib.pyplot as plt
13
  import gradio as gr
14
  from PIL import ImageDraw
15
+ from enum import Enum
16
 
17
  # Load environment variables
18
  # load_dotenv()
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ class EnumModel(str, Enum):
25
+ lion90 = "lion90"
26
+ dulcet72 = "dulcet72"
27
+ default = "default"
28
+
29
  class YOLODetector:
30
  def __init__(self, api_url: str, stride: int = 32):
31
  self.api_url = api_url
 
53
  logger.warning(f'WARNING: Image size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
54
  return new_size
55
 
56
+ def detect(self, image_path: str, conf_thres: float = 0.25, iou_thres: float = 0.45, model: EnumModel = EnumModel.dulcet72) -> Dict:
57
+ params = {'conf_thres': conf_thres, 'iou_thres': iou_thres, 'model': model.value}
58
 
59
  try:
60
  with open(image_path, 'rb') as image_file:
 
100
  plt.tight_layout()
101
  plt.show()
102
 
103
+ def process_image(image, conf_thres, iou_thres, model):
104
  api_url = os.getenv('YOLO_API_URL', 'https://akc8p0uym6ktml-8000.proxy.runpod.net/detect')
105
  detector = YOLODetector(api_url)
106
 
 
109
  temp_file_path = temp_file.name
110
 
111
  processed_image_path = detector.preprocess_image(temp_file_path)
112
+ results = detector.detect(processed_image_path, conf_thres=conf_thres, iou_thres=iou_thres, model=model)
113
 
114
  img = Image.open(processed_image_path)
115
  img_draw = img.copy()
 
129
  def main():
130
  with gr.Blocks() as demo:
131
  gr.Markdown("# YOLO Object Detection")
132
+ gr.Markdown("Upload an image, select a model, and adjust the confidence and IOU thresholds to detect objects.")
133
 
134
  with gr.Row():
135
  with gr.Column():
136
  image_input = gr.Image(type="pil", label="Input Image")
137
+ model_dropdown = gr.Dropdown(choices=[model.value for model in EnumModel], value=EnumModel.dulcet72.value, label="Model")
138
  conf_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.25, label="Confidence Threshold")
139
  iou_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.45, label="IOU Threshold")
140
  with gr.Column():
141
  image_output = gr.Image(type="pil", label="Detection Result")
142
+
143
+ def update_output(image, model, conf_thres, iou_thres):
144
  if image is None:
145
  return None
146
+ return process_image(image, conf_thres, iou_thres, EnumModel(model))
147
 
148
+ inputs = [image_input, model_dropdown, conf_slider, iou_slider]
149
  outputs = image_output
150
 
151
  image_input.change(fn=update_output, inputs=inputs, outputs=outputs)
152
+ model_dropdown.change(fn=update_output, inputs=inputs, outputs=outputs)
153
  conf_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
154
  iou_slider.change(fn=update_output, inputs=inputs, outputs=outputs)
155