jonathanagustin committed on
Commit
87a479d
1 Parent(s): 933222d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -82
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import logging
2
- import os
3
  import sys
4
- import zipfile
5
  from enum import Enum
6
- from typing import Any, Dict, List, Optional
 
7
  import cv2
8
  import gradio as gr
9
  import innertube
10
  import numpy as np
11
  import streamlink
12
- from PIL import Image, ImageDraw, ImageFont
13
  from ultralytics import YOLO
14
 
15
  logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
@@ -28,40 +27,34 @@ class SearchFilter(Enum):
28
  def __str__(self):
29
  return self.human_readable
30
 
 
31
  class SearchService:
32
  @staticmethod
33
def search(
    query: Optional[str], filter: SearchFilter = SearchFilter.VIDEO
) -> List[Dict[str, Any]]:
    """Search YouTube and return the parsed video results.

    Args:
        query: Free-text search query forwarded to ``SearchService._search``.
        filter: Search filter to apply; defaults to ``SearchFilter.VIDEO``.

    Returns:
        The list built by ``SearchService.parse`` from the raw response.
        (The previous annotation declared a two-element tuple, but only this
        single list was ever returned.)
    """
    # ``_search`` constructs its own InnerTube client, so none is needed here.
    response = SearchService._search(query, filter)
    return SearchService.parse(response)
40
 
41
  @staticmethod
42
def parse(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a raw search response into a list of video summaries.

    Only entries of the first result section that carry a ``videoRenderer``
    are kept. Each summary has the keys ``video_id``, ``thumbnail_url`` (the
    last thumbnail listed) and ``title`` (all title runs concatenated).
    """
    sections = data["contents"]["twoColumnSearchResultsRenderer"]["primaryContents"]["sectionListRenderer"]["contents"]
    entries = sections[0]["itemSectionRenderer"]["contents"] if sections else []

    parsed = []
    for entry in entries:
        if "videoRenderer" not in entry:
            continue
        video = entry["videoRenderer"]
        identifier = video["videoId"]
        heading = "".join(part["text"] for part in video["title"]["runs"])
        thumb = video["thumbnail"]["thumbnails"][-1]["url"]
        parsed.append(
            {
                "video_id": identifier,
                "thumbnail_url": thumb,
                "title": heading,
            }
        )
    return parsed
60
 
61
  @staticmethod
62
def _search(
    query: Optional[str] = None, filter: SearchFilter = SearchFilter.VIDEO
) -> Dict[str, Any]:
    """Run a raw InnerTube search and return the unparsed response.

    A fresh ``WEB`` client is created per call. The filter's ``code`` is sent
    as the ``params`` argument, or ``None`` when the filter is falsy.
    """
    web_client = innertube.InnerTube("WEB", "2.20230920.00.00")
    search_params = filter.code if filter else None
    return web_client.search(query=query, params=search_params)
@@ -71,7 +64,7 @@ class SearchService:
71
  return f"https://www.youtube.com/watch?v={video_id}"
72
 
73
  @staticmethod
74
- def get_stream_url(youtube_url):
75
  try:
76
  session = streamlink.Streamlink()
77
  streams = session.streams(youtube_url)
@@ -79,27 +72,27 @@ class SearchService:
79
  best_stream = streams.get("best")
80
  return best_stream.url if best_stream else None
81
  else:
82
- logging.warning("No streams found for this URL")
83
  return None
84
  except Exception as e:
 
85
  logging.warning(f"An error occurred: {e}")
86
  return None
87
 
 
88
  INITIAL_STREAMS = SearchService.search("world live cams", SearchFilter.LIVE)
89
- class YouTubeObjectDetection:
 
 
90
  def __init__(self):
91
  logging.getLogger().setLevel(logging.DEBUG)
92
  self.model = YOLO("yolov8x.pt")
93
- self.font_path = self.download_font(
94
- "https://www.fontsquirrel.com/fonts/download/open-sans",
95
- "open-sans.zip",
96
- )
97
  self.current_page_token = None
98
  self.streams = INITIAL_STREAMS
99
 
100
- # Gradio UI Elements
101
  initial_gallery_items = [(stream["thumbnail_url"], stream["title"]) for stream in self.streams]
102
- self.gallery = gr.Gallery(label="Live YouTube Videos", value=initial_gallery_items, show_label=True, columns=[3], rows=[10], object_fit="contain", height="auto", allow_preview=False)
103
  self.search_input = gr.Textbox(label="Search Live YouTube Videos")
104
  self.stream_input = gr.Textbox(label="URL of Live YouTube Video")
105
  self.annotated_image = gr.AnnotatedImage(show_label=False)
@@ -107,22 +100,16 @@ class YouTubeObjectDetection:
107
  self.submit_button = gr.Button("Detect Objects", variant="primary", size="lg")
108
  self.page_title = gr.HTML("<center><h1><b>Object Detection in Live YouTube Streams</b></h1></center>")
109
 
110
-
111
- @staticmethod
112
def download_font(url, save_path):
    """Download a zipped font archive with wget and unpack it into the CWD.

    Args:
        url: Location of the zip archive to fetch.
        save_path: Local file path the archive is written to.

    Returns:
        ``./OpenSans-Regular.ttf``, the path where the regular face is
        expected after extraction (the archive contents are not inspected).

    Raises:
        FileNotFoundError: If the ``wget`` executable is not available.
        subprocess.CalledProcessError: If wget exits with a non-zero status.
        zipfile.BadZipFile: If the downloaded file is not a valid archive.
    """
    # Imported locally so the block does not rely on a file-level import.
    import subprocess

    # An argument list (no shell) keeps ``url``/``save_path`` from being
    # interpreted as shell syntax; ``check=True`` surfaces a failed download
    # here instead of as a confusing zip error below.
    subprocess.run(["wget", url, "-O", save_path], check=True)
    with zipfile.ZipFile(save_path, "r") as archive:
        archive.extractall(".")
    return os.path.join(".", "OpenSans-Regular.ttf")
117
-
118
def capture_frame(self, url):
    """Grab one frame from the stream behind ``url`` and process it.

    Args:
        url: YouTube URL handed to ``SearchService.get_stream_url``.

    Returns:
        The result of ``self.process_frame(frame)`` on success. On either
        failure path, an ``(error_image, [])`` pair, so callers always get
        an image as the first element.
    """
    stream_url = SearchService.get_stream_url(url)
    if not stream_url:
        # Previously returned ``[], []``, which is not an image and differed
        # from the other failure path below.
        return self.create_error_image("No stream found"), []
    frame = self.get_frame(stream_url)
    if frame is None:
        return self.create_error_image("Failed to capture frame"), []
    return self.process_frame(frame)
 
126
 
127
  def get_frame(self, stream_url):
128
  if not stream_url:
@@ -140,7 +127,7 @@ class YouTubeObjectDetection:
140
  logging.warning(f"An error occurred while capturing the frame: {e}")
141
  return None
142
 
143
- def process_frame(self, frame):
144
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
145
  results = self.model(frame_rgb)
146
  annotations = self.get_annotations(results)
@@ -152,35 +139,21 @@ class YouTubeObjectDetection:
152
  for box in result.boxes:
153
  class_id = int(box.cls[0])
154
  class_name = result.names[class_id]
155
- bbox = tuple(map(int, box.xyxy[0]))
156
- annotations.append((bbox, class_name))
 
 
157
  return annotations
158
 
159
def create_error_image(self, message):
    """Render ``message`` centred on a black 1920x1080 image.

    Args:
        message: Text to draw, using the font at ``self.font_path`` (24 pt).

    Returns:
        A ``(1080, 1920, 3)`` uint8 array, converted with
        ``cv2.COLOR_RGB2BGR`` after drawing.
    """
    width, height = 1920, 1080
    # NumPy image arrays are (rows, cols, channels), i.e. height first. The
    # old (1920, 1080, 3) shape produced a portrait canvas while the centring
    # maths below assumed a 1920-wide one.
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    pil_image = Image.fromarray(canvas)
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype(self.font_path, 24)
    # ImageDraw.textsize() was removed in Pillow 10; textbbox() replaces it.
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    position = ((width - (right - left)) // 2, (height - (bottom - top)) // 2)
    draw.text(position, message, (0, 0, 255), font=font)
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
168
-
169
def fetch_live_streams(self, query=""):
    """Search for live streams and shape the results for the gallery.

    Args:
        query: Search text; an empty value falls back to "world live cams".

    Returns:
        A list of dicts with ``thumbnail_url``, ``title``, ``video_id`` and
        ``label`` (a copy of the video id), one per usable search result.
    """
    results = SearchService.search(query if query else "world live cams", SearchFilter.LIVE)
    # ``SearchService.parse`` emits the singular key "thumbnail_url". The old
    # check for "thumbnail_urls" never matched, so this always returned [].
    return [
        {
            "thumbnail_url": result["thumbnail_url"],
            "title": result["title"],
            "video_id": result["video_id"],
            "label": result["video_id"],
        }
        for result in results
        if "video_id" in result and "thumbnail_url" in result
    ]
184
 
185
  def render(self):
186
  with gr.Blocks(title="Object Detection in Live YouTube Streams", css="footer {visibility: hidden}") as app:
@@ -199,26 +172,26 @@ class YouTubeObjectDetection:
199
  self.gallery.render()
200
 
201
  @self.gallery.select(inputs=None, outputs=[self.annotated_image, self.stream_input])
202
- def on_gallery_select(evt: gr.SelectData):
203
- selected_index = evt.index
204
- if selected_index is not None and selected_index < len(self.streams):
205
- selected_stream = self.streams[selected_index]
206
  stream_url = SearchService.get_youtube_url(selected_stream["video_id"])
207
- frame_output = self.capture_frame(stream_url)
208
  return frame_output, stream_url
209
  return None, ""
210
 
211
  @self.search_button.click(inputs=[self.search_input], outputs=[self.gallery])
212
- def on_search_click(query):
213
- self.streams = self.fetch_live_streams(query)
214
  gallery_items = [(stream["thumbnail_url"], stream["title"]) for stream in self.streams]
215
  return gallery_items
216
 
217
  @self.submit_button.click(inputs=[self.stream_input], outputs=[self.annotated_image])
218
- def annotate_stream(url):
219
- return self.capture_frame(url)
 
 
220
 
221
- return app.queue().launch(show_api=False, debug=True, quiet=False, share=False)
222
 
223
  if __name__ == "__main__":
224
- YouTubeObjectDetection().render()
 
1
  import logging
 
2
  import sys
 
3
  from enum import Enum
4
+ from typing import Any, Dict, Optional
5
+
6
  import cv2
7
  import gradio as gr
8
  import innertube
9
  import numpy as np
10
  import streamlink
11
+ from PIL import Image
12
  from ultralytics import YOLO
13
 
14
  logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
 
27
  def __str__(self):
28
  return self.human_readable
29
 
30
+
31
  class SearchService:
32
  @staticmethod
33
def search(query: Optional[str], filter: SearchFilter = SearchFilter.VIDEO):
    """Search YouTube and return the parsed video results.

    Args:
        query: Free-text search query forwarded to ``SearchService._search``.
        filter: Search filter to apply; defaults to ``SearchFilter.VIDEO``.

    Returns:
        The list built by ``SearchService.parse`` from the raw response.
    """
    # ``_search`` constructs its own InnerTube client; the one formerly
    # created here was never used.
    response = SearchService._search(query, filter)
    return SearchService.parse(response)
38
 
39
  @staticmethod
40
def parse(data: Dict[str, Any]):
    """Extract video summaries from a raw search response.

    Looks only at the first result section and keeps entries that contain a
    ``videoRenderer``. Each summary holds ``video_id``, ``thumbnail_url`` (the
    last thumbnail listed) and ``title`` (title runs joined together).
    """
    sections = data["contents"]["twoColumnSearchResultsRenderer"]["primaryContents"]["sectionListRenderer"]["contents"]
    if not sections:
        return []

    videos = (
        entry["videoRenderer"]
        for entry in sections[0]["itemSectionRenderer"]["contents"]
        if "videoRenderer" in entry
    )
    return [
        {
            "video_id": video["videoId"],
            "thumbnail_url": video["thumbnail"]["thumbnails"][-1]["url"],
            "title": "".join(part["text"] for part in video["title"]["runs"]),
        }
        for video in videos
    ]
55
 
56
  @staticmethod
57
def _search(query: Optional[str] = None, filter: SearchFilter = SearchFilter.VIDEO):
    """Issue a raw InnerTube search and hand back the unparsed response.

    A new ``WEB`` client is built for every call. ``filter.code`` is passed as
    ``params`` unless the filter is falsy, in which case ``None`` is sent.
    """
    web_client = innertube.InnerTube("WEB", "2.20230920.00.00")
    search_params = filter.code if filter else None
    return web_client.search(query=query, params=search_params)
 
64
  return f"https://www.youtube.com/watch?v={video_id}"
65
 
66
  @staticmethod
67
+ def get_stream(youtube_url):
68
  try:
69
  session = streamlink.Streamlink()
70
  streams = session.streams(youtube_url)
 
72
  best_stream = streams.get("best")
73
  return best_stream.url if best_stream else None
74
  else:
75
+ gr.Warning(f"No streams found for: {youtube_url}")
76
  return None
77
  except Exception as e:
78
+ gr.Error(f"An error occurred while getting stream: {e}")
79
  logging.warning(f"An error occurred: {e}")
80
  return None
81
 
82
+
83
  INITIAL_STREAMS = SearchService.search("world live cams", SearchFilter.LIVE)
84
+
85
+
86
+ class LiveYouTubeObjectDetector:
87
  def __init__(self):
88
  logging.getLogger().setLevel(logging.DEBUG)
89
  self.model = YOLO("yolov8x.pt")
 
 
 
 
90
  self.current_page_token = None
91
  self.streams = INITIAL_STREAMS
92
 
93
+ # Gradio UI
94
  initial_gallery_items = [(stream["thumbnail_url"], stream["title"]) for stream in self.streams]
95
+ self.gallery = gr.Gallery(label="Live YouTube Videos", value=initial_gallery_items, show_label=True, columns=[4], rows=[5], object_fit="contain", height="auto", allow_preview=False)
96
  self.search_input = gr.Textbox(label="Search Live YouTube Videos")
97
  self.stream_input = gr.Textbox(label="URL of Live YouTube Video")
98
  self.annotated_image = gr.AnnotatedImage(show_label=False)
 
100
  self.submit_button = gr.Button("Detect Objects", variant="primary", size="lg")
101
  self.page_title = gr.HTML("<center><h1><b>Object Detection in Live YouTube Streams</b></h1></center>")
102
 
103
def detect_objects(self, url):
    """Capture one frame from a live stream and annotate it.

    Args:
        url: YouTube URL handed to ``SearchService.get_stream``.

    Returns:
        The result of ``self.annotate(frame)`` on success; otherwise a
        ``(black_image, [])`` pair so the caller still receives an image.
    """
    stream_url = SearchService.get_stream(url)
    if not stream_url:
        # ``gr.Error`` is an exception and only reaches the UI when raised,
        # which would abort the fallback return below; ``gr.Warning`` shows
        # its message when called. The message reports ``url`` because
        # ``stream_url`` is empty on this path.
        gr.Warning(f"Unable to find a stream for: {url}")
        # Called on the class: create_black_image takes no instance argument.
        return LiveYouTubeObjectDetector.create_black_image(), []
    frame = self.get_frame(stream_url)
    if frame is None:
        gr.Warning(f"Unable to capture frame for: {stream_url}")
        return LiveYouTubeObjectDetector.create_black_image(), []
    return self.annotate(frame)
113
 
114
  def get_frame(self, stream_url):
115
  if not stream_url:
 
127
  logging.warning(f"An error occurred while capturing the frame: {e}")
128
  return None
129
 
130
+ def annotate(self, frame):
131
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
132
  results = self.model(frame_rgb)
133
  annotations = self.get_annotations(results)
 
139
  for box in result.boxes:
140
  class_id = int(box.cls[0])
141
  class_name = result.names[class_id]
142
+ # EXTRACT BOUNDING BOX AND CONVERT TO INTEGER
143
+ x1, y1, x2, y2 = box.xyxy[0]
144
+ bbox_coords = (int(x1), int(y1), int(x2), int(y2))
145
+ annotations.append((bbox_coords, class_name))
146
  return annotations
147
 
148
@staticmethod
def create_black_image():
    """Return a black 1080x1920 (height x width) 3-channel uint8 image.

    Declared static because the function takes no ``self`` yet is invoked as
    ``self.create_black_image()``; without the decorator that call raises
    ``TypeError``.
    """
    # An all-zero array is the same in RGB and BGR channel order, so the
    # former PIL / OpenCV round-trip produced exactly this array.
    return np.zeros((1080, 1920, 3), dtype=np.uint8)
153
+
154
+ @staticmethod
155
def get_live_streams(query=""):
    """Search for live streams, defaulting to "world live cams".

    The fallback text is used whenever ``query`` is falsy. Results come
    straight from ``SearchService.search`` with the ``LIVE`` filter.
    """
    search_terms = query or "world live cams"
    return SearchService.search(search_terms, SearchFilter.LIVE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def render(self):
159
  with gr.Blocks(title="Object Detection in Live YouTube Streams", css="footer {visibility: hidden}") as app:
 
172
  self.gallery.render()
173
 
174
  @self.gallery.select(inputs=None, outputs=[self.annotated_image, self.stream_input])
175
+ def detect_objects_from_gallery_item(evt: gr.SelectData):
176
+ if evt.index is not None and evt.index < len(self.streams):
177
+ selected_stream = self.streams[evt.index]
 
178
  stream_url = SearchService.get_youtube_url(selected_stream["video_id"])
179
+ frame_output = self.detect_objects(stream_url)
180
  return frame_output, stream_url
181
  return None, ""
182
 
183
  @self.search_button.click(inputs=[self.search_input], outputs=[self.gallery])
184
+ def search_live_streams(query):
185
+ self.streams = self.get_live_streams(query)
186
  gallery_items = [(stream["thumbnail_url"], stream["title"]) for stream in self.streams]
187
  return gallery_items
188
 
189
  @self.submit_button.click(inputs=[self.stream_input], outputs=[self.annotated_image])
190
+ def detect_objects_from_url(url):
191
+ return self.detect_objects(url)
192
+
193
+ return app.queue().launch(show_api=False, debug=True, quiet=False, share=True)
194
 
 
195
 
196
  if __name__ == "__main__":
197
+ LiveYouTubeObjectDetector().render()