import base64
import subprocess
from io import BytesIO

import av
import requests
import torchaudio
import yt_dlp
from openai import OpenAI
from smolagents import Tool

from .speech_recognition_tool import SpeechRecognitionTool


class YoutubeVideoTool(Tool):
    name = "youtube_video"
    description = """Process a YouTube video and return the requested information from it."""
    inputs = {
        "url": {
            "type": "string",
            "description": "The URL of the YouTube video.",
        },
        "query": {
            "type": "string",
            "description": "The question to answer.",
        },
    }
    output_type = "string"

    def __init__(
        self,
        video_quality: int = 360,
        frames_interval: int | float | None = 2,
        chunk_duration: int | float | None = 20,
        speech_recognition_tool: SpeechRecognitionTool | None = None,
        client: OpenAI | None = None,
        model_id: str = "gpt-4.1-mini",
        debug: bool = False,
        **kwargs,
    ):
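        """
        Args:
            video_quality: Preferred video stream height in pixels (e.g. 360 for 360p).
            frames_interval: Seconds between captured frames; None captures every decoded frame.
            chunk_duration: Length of each processed chunk in seconds; None processes the whole video at once.
            speech_recognition_tool: Optional fallback transcriber used when no subtitles are available.
            client: OpenAI client; a default one is created if omitted.
            model_id: Vision-capable model used to answer the query.
            debug: If True, print each chunk's prompt and answer.
        """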
        self.video_quality = video_quality
        self.speech_recognition_tool = speech_recognition_tool
        self.frames_interval = frames_interval
        self.chunk_duration = chunk_duration

        self.client = client or OpenAI()
        self.model_id = model_id

        self.debug = debug

        super().__init__(**kwargs)

    def forward(self, url: str, query: str):
        """
        Process the video and return the requested information.

        Args:
            url (str): The URL of the YouTube video.
            query (str): The question to answer.

        Returns:
            str: Answer to the query.
        """
        answer = ""
        for chunk in self._split_video_into_chunks(url):
            prompt = self._prompt(chunk, query, answer)
            response = self.client.responses.create(
                model=self.model_id,
                input=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "input_text",
                                "text": prompt,
                            },
                            *[
                                {
                                    "type": "input_image",
                                    "image_url": f"data:image/jpeg;base64,{frame}",
                                }
                                for frame in self._base64_frames(chunk["frames"])
                            ],
                        ],
                    }
                ],
            )
            answer = response.output_text
            if self.debug:
                print(
                    f"CHUNK {chunk['start']} - {chunk['end']}:\n\n{prompt}\n\nANSWER:\n{answer}"
                )

            # Reset the rolling answer so "no evidence yet" is not carried
            # into the next chunk's prompt as a previous answer.
            if answer.strip() == "I need to keep watching":
                answer = ""
        return answer

    def _prompt(self, chunk, query, aggregated_answer):
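        """Build the prompt for one chunk, folding in the answer aggregated from previous chunks."""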
        prompt = [
            f"""\
These are some frames of a video that I want to upload.
I will ask a question about the entire video, but I will only show you the latest part of it.
Aggregate an answer about the entire video: use information from the previous parts, but do not reference the previous parts directly in the answer.

Ground your answer in the video title, description, captions, video frames, or the answer from the previous parts.
If no evidence is presented, just say "I need to keep watching".

VIDEO TITLE:
{chunk["title"]}

VIDEO DESCRIPTION:
{chunk["description"]}

FRAMES SUBTITLES:
{chunk["captions"]}"""
        ]

        if aggregated_answer:
            prompt.append(f"""\
Here is the answer to the same question based on the previous video parts:

BASED ON PREVIOUS PARTS:
{aggregated_answer}""")

        prompt.append(f"""\
QUESTION:
{query}""")

        return "\n\n".join(prompt)

    def _split_video_into_chunks(
        self, url: str, with_captions: bool = True, with_frames: bool = True
    ):
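        """Process the video once, then yield fixed-length chunks of its captions and frames."""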
        video = self._process_video(
            url, with_captions=with_captions, with_frames=with_frames
        )
        video_duration = video["duration"]
        chunk_duration = self.chunk_duration or video_duration

        chunk_start = 0.0
        while chunk_start < video_duration:
            chunk_end = min(chunk_start + chunk_duration, video_duration)
            chunk = self._get_video_chunk(video, chunk_start, chunk_end)
            yield chunk
            chunk_start += chunk_duration

    def _get_video_chunk(self, video, start, end):
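        """Slice out the captions and frames that overlap the [start, end] time window."""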
        chunk_captions = [
            c for c in video["captions"] if c["start"] <= end and c["end"] >= start
        ]
        chunk_frames = [
            f for f in video["frames"] if start <= f["timestamp"] <= end
        ]

        return {
            "title": video["title"],
            "description": video["description"],
            "start": start,
            "end": end,
            "captions": "\n".join([c["text"] for c in chunk_captions]),
            "frames": chunk_frames,
        }

    def _process_video(
        self, url: str, with_captions: bool = True, with_frames: bool = True
    ):
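        """Fetch metadata, captions (with speech-recognition fallback), and frames for the video."""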
        lang = "en"
        info = self._get_video_info(url, lang)

        if with_captions:
            captions = self._extract_captions(
                lang, info.get("subtitles", {}), info.get("automatic_captions", {})
            )
            if not captions and self.speech_recognition_tool:
                # No subtitle tracks available: fall back to transcribing the audio stream.
                audio_url = self._select_audio_format(info["formats"])
                audio = self._capture_audio(audio_url)
                waveform, sample_rate = torchaudio.load(audio)
                # _capture_audio resamples to 16 kHz mono, which the recognizer expects.
                assert sample_rate == 16000
                waveform_np = waveform.squeeze().numpy()
                captions = self.speech_recognition_tool.transcribe(waveform_np)
        else:
            captions = []

        if with_frames:
            video_format = self._select_video_format(info["formats"], self.video_quality)
            frames = self._capture_video_frames(video_format["url"], self.frames_interval)
        else:
            frames = []

        return {
            "id": info["id"],
            "title": info["title"],
            "description": info["description"],
            "duration": info["duration"],
            "captions": captions,
            "frames": frames,
        }

    def _get_video_info(self, url: str, lang: str):
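        """Extract video metadata and stream URLs via yt-dlp without downloading the file."""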
        ydl_opts = {
            "quiet": True,
            "skip_download": True,
            "format": (
                f"bestvideo[ext=mp4][height<={self.video_quality}]"
                f"+bestaudio[ext=m4a]/best[height<={self.video_quality}]"
            ),
            "forceurl": True,
            "noplaylist": True,
            "writesubtitles": True,
            "writeautomaticsub": True,
            "subtitlesformat": "vtt",
            "subtitleslangs": [lang],
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=False)

        return info

    def _extract_captions(self, lang, subtitles, auto_captions):
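        """Download and parse a caption track, preferring manual subtitles over automatic ones."""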
        caption_tracks = subtitles.get(lang) or auto_captions.get(lang) or []

        structured_captions = []

        srt_track = next(
            (track for track in caption_tracks if track["ext"] == "srt"), None
        )
        vtt_track = next(
            (track for track in caption_tracks if track["ext"] == "vtt"), None
        )

        if srt_track:
            import pysrt

            response = requests.get(srt_track["url"])
            response.raise_for_status()
            srt_data = response.content.decode("utf-8")

            def to_sec(t):
                return (
                    t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000
                )

            structured_captions = [
                {
                    "start": to_sec(sub.start),
                    "end": to_sec(sub.end),
                    "text": sub.text.strip(),
                }
                for sub in pysrt.from_str(srt_data)
            ]
        elif vtt_track:
            import webvtt
            from io import StringIO

            response = requests.get(vtt_track["url"])
            response.raise_for_status()
            vtt_data = response.text

            vtt_file = StringIO(vtt_data)

            def to_sec(t):
                """Convert 'HH:MM:SS.mmm' to float seconds."""
                h, m, s = t.split(":")
                s, ms = s.split(".")
                return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000

            for caption in webvtt.read_buffer(vtt_file):
                structured_captions.append(
                    {
                        "start": to_sec(caption.start),
                        "end": to_sec(caption.end),
                        "text": caption.text.strip(),
                    }
                )
        return structured_captions

    def _select_video_format(self, formats, video_quality):
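        """Pick the first video stream whose height matches the requested quality."""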
        video_format = next(
            (
                f
                for f in formats
                if f.get("vcodec") != "none" and f.get("height") == video_quality
            ),
            None,
        )
        if video_format is None:
            raise ValueError(f"No video format with height == {video_quality} found.")
        return video_format

    def _capture_video_frames(self, video_url, capture_interval_sec=None):
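        """Stream the video through ffmpeg and grab one frame every `capture_interval_sec` seconds."""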
        ffmpeg_cmd = [
            "ffmpeg",
            "-i",
            video_url,
            "-f",
            "matroska",
            "-",
        ]

        process = subprocess.Popen(
            ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
        )

        container = av.open(process.stdout)
        stream = container.streams.video[0]
        time_base = stream.time_base

        frames = []
        next_capture_time = 0
        for frame in container.decode(stream):
            if frame.pts is None:
                continue

            timestamp = float(frame.pts * time_base)
            if capture_interval_sec is None or timestamp >= next_capture_time:
                frames.append(
                    {
                        "timestamp": timestamp,
                        "image": frame.to_image(),
                    }
                )
                if capture_interval_sec is not None:
                    next_capture_time += capture_interval_sec

        process.terminate()
        process.wait()  # reap the ffmpeg child process
        return frames

    def _base64_frames(self, frames):
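        """JPEG-encode each captured frame and return the images as base64 strings."""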
        base64_frames = []
        for f in frames:
            buffered = BytesIO()
            f["image"].save(buffered, format="JPEG")
            encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_frames.append(encoded)
        return base64_frames

    def _select_audio_format(self, formats):
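        """Pick the best audio-only stream URL, preferring m4a/webm and higher bitrates."""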
        audio_formats = [
            f
            for f in formats
            if f.get("vcodec") == "none"
            and f.get("acodec")
            and f.get("acodec") != "none"
        ]

        if not audio_formats:
            raise ValueError("No valid audio-only formats found.")

        preferred_exts = ["m4a", "webm"]

        def sort_key(f):
            ext_score = (
                preferred_exts.index(f["ext"]) if f["ext"] in preferred_exts else 99
            )
            abr = f.get("abr") or 0
            return (ext_score, -abr)

        audio_formats.sort(key=sort_key)
        return audio_formats[0]["url"]

    def _capture_audio(self, audio_url) -> BytesIO:
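        """Decode the audio stream with ffmpeg into an in-memory 16 kHz mono PCM WAV."""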
        ffmpeg_audio_cmd = [
            "ffmpeg",
            "-i",
            audio_url,
            "-f",
            "wav",
            "-acodec",
            "pcm_s16le",
            "-ac",
            "1",
            "-ar",
            "16000",
            "-",
        ]

        result = subprocess.run(
            ffmpeg_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        if result.returncode != 0:
            raise RuntimeError("ffmpeg failed:\n" + result.stderr.decode())

        audio_buffer = BytesIO(result.stdout)
        return audio_buffer
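

# A minimal usage sketch, assuming this module lives in a package next to
# speech_recognition_tool.py and that OPENAI_API_KEY is set in the environment.
# The import path and URL below are placeholders, not part of the original code;
# it is kept as a comment because the relative import above prevents running
# this file directly.
#
#   from tools.youtube_video_tool import YoutubeVideoTool  # hypothetical path
#
#   tool = YoutubeVideoTool(chunk_duration=30, frames_interval=2, debug=True)
#   answer = tool.forward(
#       url="https://www.youtube.com/watch?v=VIDEO_ID",
#       query="What is the main topic of the video?",
#   )
#   print(answer)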
|