Spaces:
Runtime error
Runtime error
Add time conversions from outputs
Browse files- app.py +4 -1
- functions/convert_time.py +50 -0
- functions/model_infer.py +1 -1
- functions/punctuation.py +1 -1
- requirements.txt +2 -0
app.py
CHANGED
@@ -3,6 +3,7 @@ import re
|
|
3 |
import gradio as gr
|
4 |
from functions.punctuation import punctuate
|
5 |
from functions.model_infer import predict_from_document
|
|
|
6 |
|
7 |
|
8 |
title = "sponsoredBye - never listen to sponsors again"
|
@@ -12,16 +13,18 @@ article = "Check out [the original Rick and Morty Bot](https://huggingface.co/sp
|
|
12 |
|
13 |
def pipeline(video_url):
|
14 |
video_id = video_url.split("?v=")[-1]
|
15 |
-
punctuated_text = punctuate(video_id)
|
16 |
sentences = re.split(r"[\.\!\?]\s", punctuated_text)
|
17 |
classification, probs = predict_from_document(sentences)
|
18 |
# return punctuated_text
|
|
|
19 |
return [
|
20 |
{
|
21 |
"start": "12:05",
|
22 |
"end": "12:52",
|
23 |
"classification": str(classification),
|
24 |
"probabilities": probs,
|
|
|
25 |
}
|
26 |
]
|
27 |
|
|
|
3 |
import gradio as gr
|
4 |
from functions.punctuation import punctuate
|
5 |
from functions.model_infer import predict_from_document
|
6 |
+
from functions.convert_time import match_mask_and_transcript
|
7 |
|
8 |
|
9 |
title = "sponsoredBye - never listen to sponsors again"
|
|
|
13 |
|
14 |
def pipeline(video_url):
    """Run the full sponsor-detection pipeline for one YouTube video.

    Input:
        video_url: YouTube watch URL; the video id is taken as everything
            after the last "?v=" in the URL.
    Output: a one-element list holding a dict with the per-sentence
        classification, the raw probabilities, and the matched times.
    """
    video_id = video_url.split("?v=")[-1]
    # punctuate() returns both the punctuated text and the raw timed transcript
    punctuated_text, transcript = punctuate(video_id)
    # Split into sentences on ./!/? followed by whitespace
    sentences = re.split(r"[\.\!\?]\s", punctuated_text)
    classification, probs = predict_from_document(sentences)
    # return punctuated_text
    # Map the sentence-level 0/1 mask back onto transcript timestamps
    times = match_mask_and_transcript(sentences, transcript, classification)
    return [
        {
            # NOTE(review): "start"/"end" look like hard-coded placeholders —
            # presumably meant to be derived from `times`; confirm before release.
            "start": "12:05",
            "end": "12:52",
            "classification": str(classification),
            "probabilities": probs,
            "times": times,
        }
    ]
|
30 |
|
functions/convert_time.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from thefuzz import fuzz
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def match_mask_and_transcript(split_punct, transcript, classification):
    """Map a per-sentence sponsor mask back onto transcript timestamps.

    Input:
        split_punct: the punctuated text, split on ?/!/.\s (list of sentences)
        transcript: original transcript with timestamps — a list of dicts with
            "text", "start" and "duration" keys
        classification: classification object (list of numbers 0,1), aligned
            index-for-index with split_punct
    Output: times — list of (start, end) tuples in seconds covering the
        chunks that fuzzy-match the sponsored sentences
    """
    # Guard: with no transcript there is nothing to match (also avoids
    # np.mean/np.std NaN warnings on an empty array).
    if not transcript:
        return []

    # Get the sponsored part (sentences classified as 1)
    sponsored_segment = []
    for i, val in enumerate(classification):
        if val == 1:
            sponsored_segment.append(split_punct[i])

    segment = " ".join(sponsored_segment)

    # Check the similarity scores between the sponsored part and the
    # transcript parts
    sim_scores = [fuzz.partial_ratio(segment, elem["text"]) for elem in transcript]

    # Get the scores and keep chunks scoring above mean + 2*stdev
    scores = np.array(sim_scores)
    timestamp_mask = (scores > np.mean(scores) + np.std(scores) * 2).astype(int)
    timestamps = [
        (transcript[i]["start"], transcript[i]["duration"])
        for i, elem in enumerate(timestamp_mask)
        if elem == 1
    ]

    # Merge the matched chunks into contiguous (start, end) segments.
    # Threshold of 5 s decides whether a chunk starts a new segment or
    # extends the current one (so small neighbouring chunks are joined).
    times = []
    current_time = 0
    for start, duration in timestamps:
        end = start + duration
        # BUGFIX: the original indexed times[current] with current == -1 when
        # the very first matched chunk started within 5 s of t=0, raising
        # IndexError on the empty list. An empty `times` now always starts a
        # new segment.
        if not times or start > current_time + 5:
            times.append((start, end))
        else:
            times[-1] = (times[-1][0], end)
        current_time = end
    return times
|
functions/model_infer.py
CHANGED
@@ -41,6 +41,6 @@ def predict_from_document(sentences):
|
|
41 |
# Set the prediction threshold to 0.8 instead of 0.5, now use mean
|
42 |
output = (
|
43 |
prediction.flatten()[: len(sentences)]
|
44 |
-
>= np.mean(prediction) + np.
|
45 |
).astype(int)
|
46 |
return output, prediction.flatten()[: len(sentences)]
|
|
|
41 |
# Set the prediction threshold to 0.8 instead of 0.5, now use mean
|
42 |
output = (
|
43 |
prediction.flatten()[: len(sentences)]
|
44 |
+
>= np.mean(prediction) + np.std(prediction) * 2
|
45 |
).astype(int)
|
46 |
return output, prediction.flatten()[: len(sentences)]
|
functions/punctuation.py
CHANGED
@@ -55,4 +55,4 @@ def punctuate(video_id):
|
|
55 |
) # Get the transcript from the YoutubeTranscriptApi
|
56 |
resp = query_punctuation(splits) # Get the response from the Inference API
|
57 |
punctuated_transcript = parse_output(resp, splits)
|
58 |
-
return punctuated_transcript
|
|
|
55 |
) # Get the transcript from the YoutubeTranscriptApi
|
56 |
resp = query_punctuation(splits) # Get the response from the Inference API
|
57 |
punctuated_transcript = parse_output(resp, splits)
|
58 |
+
return punctuated_transcript, transcript
|
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
youtube_transcript_api
|
|
|
|
|
2 |
tensorflow==2.15
|
3 |
keras
|
4 |
keras-nlp
|
|
|
1 |
youtube_transcript_api
|
2 |
+
thefuzz
|
3 |
+
numpy
|
4 |
tensorflow==2.15
|
5 |
keras
|
6 |
keras-nlp
|