import gradio as gr
import requests
import re
from transformers import AutoTokenizer, pipeline
from youtube_transcript_api._transcripts import TranscriptListFetcher
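
# Sponsor classifier: a token-classification model fine-tuned locally
# (./checkpoint-6000). aggregation_strategy="first" merges sub-word tokens
# back into word-level spans, labelling each word by its first sub-token.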
tagger = pipeline(
    "token-classification",
    "./checkpoint-6000",
    aggregation_strategy="first",
)
tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
max_size = 512
classes = [False, True]
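
# Match the common YouTube URL shapes (watch, embed, youtu.be,
# youtube-nocookie) and capture the 11-character video ID in group 1.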
pattern = re.compile(
    r"(?:https?:\/\/)?(?:[0-9A-Z-]+\.)?(?:youtube|youtu|youtube-nocookie)\.(?:com|be)\/(?:watch\?v=|watch\?.+&v=|embed\/|v\/|.+\?v=)?([^&=\n%\?]{11})",
    re.IGNORECASE,  # hostnames like "www.youtube.com" arrive lowercase
)
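
# Resolve a URL to its video ID, or None if no pattern matched.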
def video_id(url):
    p = pattern.match(url)
    return p.group(1) if p else None
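
# Flatten a "json3" timedtext payload into word-level timings. Each event
# carries a line start ("tStartMs") and word segments ("segs") holding the
# word text ("utf8") and a millisecond offset into the line ("tOffsetMs");
# a lone "\n" segment marks the end of a caption line. A payload looks
# roughly like (illustrative values):
#   {"events": [{"tStartMs": 0,
#                "segs": [{"utf8": "hello"}, {"utf8": " world", "tOffsetMs": 480}]}]}
# Output: a list of {"w": word, "s": start_ms, "e": end_ms} dicts.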
def process(obj):
    o = obj["events"]
    new_l = []
    start_dur = None
    for line in o:
        if "segs" in line:
            if len(line["segs"]) == 1 and line["segs"][0]["utf8"] == "\n":
                if start_dur is not None:
                    new_l.append(
                        {
                            "w": prev["utf8"],
                            "s": start_dur + prev["tOffsetMs"],
                            "e": line["tStartMs"],
                        }
                    )
                continue
            start_dur = line["tStartMs"]
            prev = line["segs"][0]
            prev["tOffsetMs"] = 0
            for word in line["segs"][1:]:
                try:
                    new_l.append(
                        {
                            "w": prev["utf8"],
                            "s": start_dur + prev["tOffsetMs"],
                            "e": start_dur + word["tOffsetMs"],
                        }
                    )
                    prev = word
                except KeyError:
                    pass
    return new_l
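
# Download the English caption track in json3 format. NOTE: this reaches into
# private internals of youtube_transcript_api (TranscriptListFetcher and its
# underscore-prefixed methods), which may break across library versions.
# Returns None when the video has no English track.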
def get_transcript(video_id, session):
    fetcher = TranscriptListFetcher(session)
    _json = fetcher._extract_captions_json(
        fetcher._fetch_video_html(video_id), video_id
    )
    captionTracks = _json["captionTracks"]
    transcript_track_url = ""
    for track in captionTracks:
        if track["languageCode"] == "en":
            transcript_track_url = track["baseUrl"] + "&fmt=json3"
    if not transcript_track_url:
        return None
    obj = session.get(transcript_track_url)
    p = process(obj.json())
    return p
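
# Gradio callback for the "Fetch Transcript" button: join the timed words
# into one plain-text transcript string.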
def transcript(url):
    i = video_id(url)
    if i:
        words = get_transcript(i, requests.Session())
        if words is not None:
            return " ".join(l["w"].strip() for l in words)
    return "ERROR: Failed to load transcript (is the link a valid YouTube URL?)..."
def inference(transcript):
    tokens = tokenizer(transcript.split(" "))["input_ids"]
    current_length = 0
    current_word_length = 0
    batches = []
    for i, w in enumerate(tokens):
        word = w[:-1] if i == 0 else w[1:] if i == (len(tokens) - 1) else w[1:-1]
        if (current_length + len(word)) > max_size:
            # The extra "- 1" re-includes the word skipped by the previous
            # flush's `continue`, so no word is lost between batches.
            batch = " ".join(
                tokenizer.batch_decode(
                    [
                        tok[1:-1]
                        for tok in tokens[max(0, i - current_word_length - 1) : i]
                    ]
                )
            )
            batches.append(batch)
            current_word_length = 0
            current_length = 0
            continue
        current_length += len(word)
        current_word_length += 1
    if current_length > 0:
        # max() guards the short-transcript case where no batch was flushed
        # above and i - current_word_length would go negative, slicing away
        # everything but the last word.
        batches.append(
            " ".join(
                tokenizer.batch_decode(
                    [tok[1:-1] for tok in tokens[max(0, i - current_word_length) :]]
                )
            )
        )
    results = []
    for split in batches:
        values = tagger(split)
        results.extend(
            {
                "sponsor": v["entity_group"] == "LABEL_1",
                "phrase": v["word"],
            }
            for v in values
        )
    return results
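
# Map tagged spans to the (text, label) pairs gr.HighlightedText expects;
# non-sponsor spans get None so they render unhighlighted.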
def predict(transcript):
    return [
        (span["phrase"], "Sponsor" if span["sponsor"] else None)
        for span in inference(transcript)
    ]
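
# Two-column UI: fetch a transcript on the left, highlight predicted
# sponsor segments on the right.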
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Video URL", placeholder="Video URL", lines=1, max_lines=1)
            btn = gr.Button("Fetch Transcript")
            gr.Examples(["youtu.be/xsLJZyih3Ac"], [inp])
            text = gr.Textbox(label="Transcript", placeholder="<generated transcript>")
            btn.click(fn=transcript, inputs=inp, outputs=text)
        with gr.Column():
            p = gr.Button("Predict Sponsors")
            highlight = gr.HighlightedText()
            p.click(fn=predict, inputs=text, outputs=highlight)

demo.launch()