jonathanagustin committed on
Commit
c4cb2d4
1 Parent(s): 97ed94f

simplify annotations

Browse files
Files changed (1) hide show
  1. app.py +4 -20
app.py CHANGED
@@ -215,7 +215,6 @@ class LiveYouTubeObjectDetector:
215
  detect_objects: Detects objects in a live YouTube stream given its URL.
216
  get_frame: Captures a frame from a live stream URL.
217
  annotate: Annotates a frame with detected objects.
218
- get_annotations: Converts YOLO detection results into annotations for Gradio.
219
  create_black_image: Creates a black placeholder image.
220
  get_live_streams: Searches for live streams based on a query.
221
  render: Sets up and launches the Gradio interface.
@@ -281,7 +280,7 @@ class LiveYouTubeObjectDetector:
281
 
282
  def annotate(self, frame: np.ndarray) -> Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]:
283
  """
284
- Annotates the given frame with detected objects.
285
 
286
  :param frame: The frame to be annotated.
287
  :type frame: np.ndarray
@@ -290,22 +289,6 @@ class LiveYouTubeObjectDetector:
290
  """
291
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
292
  results = self.model(frame_rgb)
293
- annotations = self.get_annotations(results)
294
- return Image.fromarray(frame_rgb), annotations
295
-
296
- def get_annotations(self, results) -> List[Tuple[Tuple[int, int, int, int], str]]:
297
- """
298
- Converts YOLO detection results into annotations suitable for Gradio visualization.
299
-
300
- This method processes the results from the YOLO object detection model, extracting
301
- the bounding box coordinates and class names for each detected object.
302
-
303
- :param results: The detection results from the YOLO model.
304
- :type results: List[DetectionResult]
305
- :return: A list of tuples, each containing the bounding box (as a tuple of integers)
306
- and the class name of the detected object.
307
- :rtype: List[Tuple[Tuple[int, int, int, int], str]]
308
- """
309
  annotations = []
310
  for result in results:
311
  for box in result.boxes:
@@ -315,7 +298,8 @@ class LiveYouTubeObjectDetector:
315
  x1, y1, x2, y2 = box.xyxy[0]
316
  bbox_coords = (int(x1), int(y1), int(x2), int(y2))
317
  annotations.append((bbox_coords, class_name))
318
- return annotations
 
319
 
320
  @staticmethod
321
  def create_black_image():
@@ -374,7 +358,7 @@ class LiveYouTubeObjectDetector:
374
  with gr.Row():
375
  self.gallery.render()
376
 
377
- @self.gallery.select(inputs=None, outputs=[self.annotated_image, self.stream_input])
378
  def detect_objects_from_gallery_item(evt: gr.SelectData):
379
  if evt.index is not None and evt.index < len(self.streams):
380
  selected_stream = self.streams[evt.index]
 
215
  detect_objects: Detects objects in a live YouTube stream given its URL.
216
  get_frame: Captures a frame from a live stream URL.
217
  annotate: Annotates a frame with detected objects.
 
218
  create_black_image: Creates a black placeholder image.
219
  get_live_streams: Searches for live streams based on a query.
220
  render: Sets up and launches the Gradio interface.
 
280
 
281
  def annotate(self, frame: np.ndarray) -> Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]:
282
  """
283
+ Annotates the given frame with detected objects and their bounding boxes.
284
 
285
  :param frame: The frame to be annotated.
286
  :type frame: np.ndarray
 
289
  """
290
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
291
  results = self.model(frame_rgb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  annotations = []
293
  for result in results:
294
  for box in result.boxes:
 
298
  x1, y1, x2, y2 = box.xyxy[0]
299
  bbox_coords = (int(x1), int(y1), int(x2), int(y2))
300
  annotations.append((bbox_coords, class_name))
301
+
302
+ return Image.fromarray(frame_rgb), annotations
303
 
304
  @staticmethod
305
  def create_black_image():
 
358
  with gr.Row():
359
  self.gallery.render()
360
 
361
+ @self.gallery.select(inputs=None, outputs=[self.annotated_image, self.stream_input], scroll_to_output=True)
362
  def detect_objects_from_gallery_item(evt: gr.SelectData):
363
  if evt.index is not None and evt.index < len(self.streams):
364
  selected_stream = self.streams[evt.index]