# NOTE: the original paste began with "Spaces:" / "Runtime error" lines —
# Hugging Face Spaces page chrome, not Python source.
import gradio as gr | |
import requests | |
import re | |
from transformers import AutoTokenizer, pipeline | |
from youtube_transcript_api._transcripts import TranscriptListFetcher | |
# Token-classification pipeline fine-tuned to tag sponsor segments; the
# checkpoint directory is expected to ship alongside this Space.
tagger = pipeline(
    "token-classification",
    "./checkpoint-6000",
    aggregation_strategy="first",  # merge subword pieces into whole-word groups
)
# The same checkpoint's tokenizer, used by inference() to re-chunk transcripts.
tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
# Cap on subword tokens per tagger call — presumably the model's context
# limit; TODO confirm against the checkpoint config.
max_size = 512
# Index 0 = not sponsor, 1 = sponsor.  NOTE(review): unused in this file —
# looks like a leftover; confirm before removing.
classes = [False, True]
# Matches the 11-character video id in the common YouTube URL shapes
# (youtube.com/watch?v=..., youtu.be/..., /embed/, /v/, nocookie domain),
# with or without a scheme or subdomain.
# Fix: compiled with re.IGNORECASE — the class [0-9A-Z-] was uppercase-only,
# so lowercase subdomains ("www.", "m.") made real URLs fail to match.
pattern = re.compile(
    r"(?:https?:\/\/)?(?:[0-9A-Z-]+\.)?(?:youtube|youtu|youtube-nocookie)\.(?:com|be)\/(?:watch\?v=|watch\?.+&v=|embed\/|v\/|.+\?v=)?([^&=\n%\?]{11})",
    re.IGNORECASE,
)


def video_id(url):
    """Extract the 11-character YouTube video id from *url*.

    Returns the id string, or None when *url* is not a recognizable
    YouTube link.
    """
    m = pattern.match(url)
    return m.group(1) if m else None
def process(obj):
    """Flatten a YouTube json3 timed-text payload into per-word timings.

    Returns a list of ``{"w": word, "s": start_ms, "e": end_ms}`` dicts.
    Events without a ``segs`` key are ignored; a lone ``"\\n"`` seg acts
    as an end-of-line marker that closes out the last pending word.
    """
    words = []
    segment_start = None  # tStartMs of the caption segment being walked
    for event in obj["events"]:
        if "segs" not in event:
            continue
        segs = event["segs"]
        if len(segs) == 1 and segs[0]["utf8"] == "\n":
            # End-of-line marker: flush the pending word, ending it at
            # this event's start time (only once a segment has begun).
            if segment_start is not None:
                words.append(
                    {
                        "w": pending["utf8"],
                        "s": segment_start + pending["tOffsetMs"],
                        "e": event["tStartMs"],
                    }
                )
            continue
        segment_start = event["tStartMs"]
        pending = segs[0]
        pending["tOffsetMs"] = 0  # first word starts at the segment start
        for seg in segs[1:]:
            try:
                words.append(
                    {
                        "w": pending["utf8"],
                        "s": segment_start + pending["tOffsetMs"],
                        "e": segment_start + seg["tOffsetMs"],
                    }
                )
                pending = seg
            except KeyError:
                # Seg missing "utf8"/"tOffsetMs": skip it entirely.
                pass
    return words
def get_transcript(video_id, session):
    """Fetch the English caption track for *video_id* as word timings.

    Returns process()'s word-timing list, or None when the video has no
    English caption track.
    NOTE(review): relies on youtube_transcript_api private helpers
    (_fetch_video_html / _extract_captions_json) — verify these still
    exist in the installed version.
    """
    fetcher = TranscriptListFetcher(session)
    captions = fetcher._extract_captions_json(
        fetcher._fetch_video_html(video_id), video_id
    )
    track_url = ""
    # No break: if several English tracks exist, the last one wins.
    for track in captions["captionTracks"]:
        if track["languageCode"] == "en":
            track_url = track["baseUrl"] + "&fmt=json3"
    if not track_url:
        return None
    response = session.get(track_url)
    return process(response.json())
def transcript(url):
    """Return the space-joined English transcript for *url*.

    On failure (unrecognizable URL, or no English caption track) returns a
    human-readable error string instead of raising, since the result is
    shown directly in the UI textbox.
    """
    vid = video_id(url)
    if not vid:
        # Typo fixed: "it the link" -> "is the link".
        return "ERROR: Failed to load transcript (is the link a valid youtube url?)..."
    words = get_transcript(vid, requests.Session())
    # Fix: get_transcript returns None when no English track exists;
    # joining over None raised TypeError in the original.
    if words is None:
        return "ERROR: No English transcript available for this video..."
    return " ".join(w["w"].strip() for w in words)
def inference(transcript):
    """Tag sponsor spans in *transcript* with the token-classification model.

    The transcript is split on spaces, re-chunked into batches whose
    subword-token count stays under ``max_size``, and each batch is run
    through ``tagger``.  Returns a list of
    ``{"sponsor": bool, "phrase": str}`` dicts, one per aggregated span.
    """
    # Per-word subword ids; each entry carries the tokenizer's special
    # start/end tokens, which the slicing below strips.
    tokens = tokenizer(transcript.split(" "))["input_ids"]
    current_length = 0       # subword tokens accumulated in the current batch
    current_word_length = 0  # whole words accumulated in the current batch
    batches = []
    for i, w in enumerate(tokens):
        # Keep the leading special token only on the first word and the
        # trailing one only on the last, so lengths aren't double-counted.
        word = w[:-1] if i == 0 else w[1:] if i == (len(tokens) - 1) else w[1:-1]
        if (current_length + len(word)) > max_size:
            # Flush the words gathered so far as one decoded batch.
            # NOTE(review): the bounds look off by one — the extra "- 1"
            # re-includes a word from the previous batch, and word ``i``
            # itself is neither counted nor flushed here (it is silently
            # dropped from all batches).  Confirm against the tokenizer.
            batch = " ".join(
                tokenizer.batch_decode(
                    [
                        tok[1:-1]
                        for tok in tokens[max(0, i - current_word_length - 1) : i]
                    ]
                )
            )
            batches.append(batch)
            current_word_length = 0
            current_length = 0
            continue
        current_length += len(word)
        current_word_length += 1
    if current_length > 0:
        # Flush the trailing partial batch.  ``i`` is the last loop index;
        # this raises NameError if ``tokens`` is empty.
        batches.append(
            " ".join(
                tokenizer.batch_decode(
                    [tok[1:-1] for tok in tokens[i - current_word_length :]]
                )
            )
        )
    results = []
    for split in batches:
        values = tagger(split)
        # LABEL_1 is treated as the sponsor class (predict() renders these
        # spans as "Sponsor") — presumably fixed by the fine-tuned
        # checkpoint's label map; TODO confirm.
        results.extend(
            {
                "sponsor": v["entity_group"] == "LABEL_1",
                "phrase": v["word"],
            }
            for v in values
        )
    return results
def predict(transcript):
    """Map inference() spans to (phrase, label) pairs for gr.HighlightedText.

    Sponsor spans get the label "Sponsor"; everything else gets None
    (rendered unhighlighted).
    """
    pairs = []
    for span in inference(transcript):
        label = "Sponsor" if span["sponsor"] else None
        pairs.append((span["phrase"], label))
    return pairs
# Two-column UI: fetch a transcript from a URL on the left, run the
# sponsor tagger over it on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            url_box = gr.Textbox(
                label="Video URL", placeholder="Video URL", lines=1, max_lines=1
            )
            fetch_btn = gr.Button("Fetch Transcript")
            gr.Examples(["youtu.be/xsLJZyih3Ac"], [url_box])
            transcript_box = gr.Textbox(
                label="Transcript", placeholder="<generated transcript>"
            )
            fetch_btn.click(fn=transcript, inputs=url_box, outputs=transcript_box)
        with gr.Column():
            predict_btn = gr.Button("Predict Sponsors")
            highlight = gr.HighlightedText()
            predict_btn.click(fn=predict, inputs=transcript_box, outputs=highlight)

demo.launch()