jonathanagustin commited on
Commit
b864380
1 Parent(s): be7e9ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -2
app.py CHANGED
@@ -1,2 +1,283 @@
1
- import streamlit as st
2
- st.write("Hello, World!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import zipfile
5
+ from enum import Enum
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ os.system("python3 -m pip uninstall -y typing-extensions")
9
+ os.system("python3 -m pip install -U typing-extensions")
10
+ os.system(
11
+ "python3 -m pip install -q --progress-bar off streamlink gradio tiktoken ultralytics pillow innertube"
12
+ )
13
+ import cv2
14
+ import gradio as gr
15
+ import innertube
16
+ import numpy as np
17
+ import streamlink
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ from ultralytics import YOLO
20
+
21
+ logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
22
+
23
+ model = YOLO("yolov8x.pt")
24
+
25
+
26
+ class SearchFilter(Enum):
27
+ LIVE = ("EgJAAQ%3D%3D", "Live")
28
+ VIDEO = ("EgIQAQ%3D%3D", "Video")
29
+
30
+ def __init__(self, code, human_readable):
31
+ self.code = code
32
+ self.human_readable = human_readable
33
+
34
+ def __str__(self):
35
+ return self.human_readable
36
+
37
+
38
+ class SearchService:
39
+ @staticmethod
40
+ def search(
41
+ query: Optional[str], filter: SearchFilter = SearchFilter.VIDEO
42
+ ) -> (List[Dict[str, Any]], Optional[str]):
43
+ client = innertube.InnerTube("WEB", "2.20230920.00.00")
44
+ response = SearchService._search(query, filter)
45
+ results = SearchService.parse(response)
46
+ return results
47
+
48
+ @staticmethod
49
+ def parse(data: Dict[str, Any]) -> List[Dict[str, Any]]:
50
+ results = []
51
+ items = []
52
+
53
+ contents = (
54
+ data.get("contents", {})
55
+ .get("twoColumnSearchResultsRenderer", {})
56
+ .get("primaryContents", {})
57
+ .get("sectionListRenderer", {})
58
+ .get("contents", [])
59
+ )
60
+ if contents:
61
+ items = contents[0].get("itemSectionRenderer", {}).get("contents", [])
62
+
63
+ for item in items:
64
+ if "videoRenderer" in item:
65
+ renderer = item["videoRenderer"]
66
+ video_id = renderer.get("videoId", "")
67
+ thumbnail_urls = [
68
+ thumb.get("url", "")
69
+ for thumb in renderer.get("thumbnail", {}).get("thumbnails", [])
70
+ ]
71
+ title_text = "".join(
72
+ [
73
+ run.get("text", "")
74
+ for run in renderer.get("title", {}).get("runs", [])
75
+ ]
76
+ )
77
+
78
+ result = {
79
+ "video_id": video_id,
80
+ "thumbnail_urls": thumbnail_urls,
81
+ "title": title_text,
82
+ }
83
+ results.append(result)
84
+
85
+ return results
86
+
87
+ @staticmethod
88
+ def _search(
89
+ query: Optional[str] = None, filter: SearchFilter = SearchFilter.VIDEO
90
+ ) -> Dict[str, Any]:
91
+ client = innertube.InnerTube("WEB", "2.20230920.00.00")
92
+ response = client.search(query=query, params=filter.code if filter else None)
93
+ return response
94
+
95
+ @staticmethod
96
+ def get_youtube_url(video_id: str) -> str:
97
+ return f"https://www.youtube.com/watch?v={video_id}"
98
+
99
+ @staticmethod
100
+ def get_stream_url(youtube_url):
101
+ try:
102
+ session = streamlink.Streamlink()
103
+ streams = session.streams(youtube_url)
104
+ if streams:
105
+ best_stream = streams.get("best")
106
+ return best_stream.url if best_stream else None
107
+ else:
108
+ logging.warning("No streams found for this URL")
109
+ return None
110
+ except Exception as e:
111
+ logging.warning(f"An error occurred: {e}")
112
+ return None
113
+
114
+
115
+ class LiveStreamAnnotator:
116
+ def __init__(self):
117
+ logging.getLogger().setLevel(logging.DEBUG)
118
+ self.model = YOLO("yolov8x.pt")
119
+ self.font_path = self.download_font(
120
+ "https://www.fontsquirrel.com/fonts/download/open-sans",
121
+ "/content/open-sans.zip",
122
+ )
123
+ self.current_page_token = None
124
+ self.streams = self.fetch_live_streams("world live cams")
125
+ # Gradio UI Elements
126
+ initial_gallery_items = [
127
+ (stream["thumbnail_url"], stream["title"]) for stream in self.streams
128
+ ]
129
+ self.gallery = gr.Gallery(
130
+ label="Live YouTube Videos",
131
+ value=initial_gallery_items,
132
+ show_label=False,
133
+ columns=[3],
134
+ rows=[10],
135
+ object_fit="contain",
136
+ height="auto",
137
+ )
138
+ self.search_input = gr.Textbox(label="Search Live YouTube Videos")
139
+ self.stream_input = gr.Textbox(label="URL of Live YouTube Video")
140
+ self.output_image = gr.AnnotatedImage(show_label=False)
141
+ self.search_button = gr.Button("Search")
142
+ self.submit_button = gr.Button("Detect Objects", variant="primary", size="lg")
143
+ self.prev_page_button = gr.Button("Previous Page", interactive=False)
144
+ self.next_page_button = gr.Button("Next Page", interactive=False)
145
+
146
+ @staticmethod
147
+ def download_font(url, save_path):
148
+ os.system(f"wget {url} -O {save_path}")
149
+ with zipfile.ZipFile(save_path, "r") as zip_ref:
150
+ font_dir = "/usr/share/fonts/open-sans"
151
+ zip_ref.extractall(font_dir)
152
+ return os.path.join(font_dir, "OpenSans-Regular.ttf")
153
+
154
+ def capture_frame(self, url):
155
+ stream_url = SearchService.get_stream_url(url)
156
+ if not stream_url:
157
+ return self.create_error_image("No stream found"), []
158
+ frame = self.get_frame(stream_url)
159
+ if frame is None:
160
+ return self.create_error_image("Failed to capture frame"), []
161
+ return self.process_frame(frame)
162
+
163
+ def get_frame(self, stream_url):
164
+ if not stream_url:
165
+ return None
166
+ try:
167
+ cap = cv2.VideoCapture(stream_url)
168
+ ret, frame = cap.read()
169
+ cap.release()
170
+ if ret:
171
+ return cv2.resize(frame, (1920, 1080))
172
+ else:
173
+ logging.warning(
174
+ "Unable to process the HLS stream with cv2.VideoCapture."
175
+ )
176
+ return None
177
+ except Exception as e:
178
+ logging.warning(f"An error occurred while capturing the frame: {e}")
179
+ return None
180
+
181
+ def process_frame(self, frame):
182
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
183
+ results = self.model(frame_rgb)
184
+ annotations = self.get_annotations(results)
185
+ return Image.fromarray(frame_rgb), annotations
186
+
187
+ def get_annotations(self, results):
188
+ annotations = []
189
+ for result in results:
190
+ for box in result.boxes:
191
+ class_id = int(box.cls[0])
192
+ class_name = result.names[class_id]
193
+ bbox = tuple(map(int, box.xyxy[0]))
194
+ annotations.append((bbox, class_name))
195
+ return annotations
196
+
197
+ @staticmethod
198
+ def create_error_image(message):
199
+ error_image = np.zeros((1920, 1080, 3), dtype=np.uint8)
200
+ pil_image = Image.fromarray(error_image)
201
+ draw = ImageDraw.Draw(pil_image)
202
+ font = ImageFont.truetype("/usr/share/fonts/open-sans/OpenSans-Regular.ttf", 24)
203
+ text_size = draw.textsize(message, font=font)
204
+ position = ((1920 - text_size[0]) // 2, (1080 - text_size[1]) // 2)
205
+ draw.text(position, message, (0, 0, 255), font=font)
206
+ return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
207
+
208
+ def fetch_live_streams(self, query=""):
209
+ streams = []
210
+ results = SearchService.search(
211
+ query if query else "world live cams", SearchFilter.LIVE
212
+ )
213
+ for result in results:
214
+ if "video_id" in result and "thumbnail_urls" in result:
215
+ streams.append(
216
+ {
217
+ "thumbnail_url": result["thumbnail_urls"][0]
218
+ if result["thumbnail_urls"]
219
+ else "",
220
+ "title": result["title"],
221
+ "video_id": result["video_id"],
222
+ "label": result["video_id"],
223
+ }
224
+ )
225
+ return streams
226
+
227
+ def render(self):
228
+ with gr.Blocks(
229
+ title="Object Detection in Live YouTube Streams",
230
+ css="footer {visibility: hidden}",
231
+ ) as app:
232
+ gr.HTML(
233
+ "<center><h1><b>Object Detection in Live YouTube Streams</b></h1></center>"
234
+ )
235
+ with gr.Column():
236
+ self.stream_input.render()
237
+ with gr.Group():
238
+ self.output_image.render()
239
+ self.submit_button.render()
240
+ with gr.Group():
241
+ with gr.Row():
242
+ self.search_input.render()
243
+ self.search_button.render()
244
+ with gr.Row():
245
+ self.gallery.render()
246
+
247
+ @self.gallery.select(
248
+ inputs=None, outputs=[self.output_image, self.stream_input]
249
+ )
250
+ def on_gallery_select(evt: gr.SelectData):
251
+ selected_index = evt.index
252
+ if selected_index is not None and selected_index < len(self.streams):
253
+ selected_stream = self.streams[selected_index]
254
+ stream_url = SearchService.get_youtube_url(
255
+ selected_stream["video_id"]
256
+ )
257
+ frame_output = self.capture_frame(stream_url)
258
+ return frame_output, stream_url
259
+ return None, ""
260
+
261
+ @self.search_button.click(
262
+ inputs=[self.search_input], outputs=[self.gallery]
263
+ )
264
+ def on_search_click(query):
265
+ self.streams = self.fetch_live_streams(query)
266
+ gallery_items = [
267
+ (stream["thumbnail_url"], stream["title"])
268
+ for stream in self.streams
269
+ ]
270
+ return gallery_items
271
+
272
+ @self.submit_button.click(
273
+ inputs=[self.stream_input], outputs=[self.output_image]
274
+ )
275
+ def annotate_stream(url):
276
+ return self.capture_frame(url)
277
+
278
+ app.queue().launch(
279
+ share=True, debug=True, quiet=True, show_api=False, height=800
280
+ )
281
+
282
+
283
+ LiveStreamAnnotator().render()