File size: 5,594 Bytes
d8e813d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from collections import deque
import streamlit as st
import torch
from streamlit_player import st_player
from transformers import AutoModelForCTC, Wav2Vec2Processor
from streaming import ffmpeg_stream

# Prefer CUDA when a GPU is visible to torch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Keyword options forwarded to `st_player`: autoplay with hidden controls, and
# emit `onProgress` events (every `progress_interval`, presumably milliseconds)
# so the app can read the current playback position.
player_options = dict(
    events=["onProgress"],
    progress_interval=200,
    volume=1.0,
    playing=True,
    loop=False,
    controls=False,
    muted=False,
    # YouTube embed parameters: begin playback at second 1.
    config={"youtube": {"playerVars": {"start": 1}}},
)

# disable rapid fading in and out on `st.code` updates
st.markdown(
    "<style>.element-container{opacity:1 !important}</style>",
    unsafe_allow_html=True,
)

def _cache_model_loader(func):
    """Wrap ``func`` with Streamlit caching, preferring ``st.cache_resource``.

    ``st.cache`` is deprecated in newer Streamlit releases; ``st.cache_resource``
    is its replacement for shared, unhashable objects such as models. On older
    Streamlit versions that lack it, fall back to ``st.cache`` with torch
    ``Parameter`` objects skipped during hashing (as the original code did).
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})(func)


@_cache_model_loader
def load_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
    """Load the Wav2Vec2 processor and CTC model for ``model_path``.

    The model is moved to the module-level ``device``. The result is cached by
    Streamlit, so reruns reuse the already-loaded weights.

    Args:
        model_path: Hugging Face model id or local path.

    Returns:
        tuple: ``(processor, model)``.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AutoModelForCTC.from_pretrained(model_path).to(device)
    return processor, model


processor, model = load_model()

def stream_text(url, chunk_duration_ms, pad_duration_ms):
    """Yield transcribed text for the audio at ``url``, one line per chunk.

    Audio chunks come from ``ffmpeg_stream`` (assumed: 1-D sample arrays at the
    processor's sampling rate, padded by ``pad_duration_ms`` on each side except
    before the first chunk -- TODO confirm against ``streaming.ffmpeg_stream``).
    Each chunk is run through the CTC model and greedily decoded; logits that
    belong to the padding are trimmed. The last word of each chunk is held back,
    because it may be cut off, and is prepended to the next chunk's output.

    Args:
        url: media URL passed to ``ffmpeg_stream``.
        chunk_duration_ms: chunk length in milliseconds.
        pad_duration_ms: padding length in milliseconds.

    Yields:
        str: exactly one (possibly empty) line per chunk, then one final line
        flushing any held-back word.
    """
    sampling_rate = processor.feature_extractor.sampling_rate

    # Number of logit frames produced by the padding audio; these are cut from
    # the edges of each chunk's output to account for the input padding.
    output_pad_len = model._get_feat_extract_output_lengths(int(sampling_rate * pad_duration_ms / 1000))

    # define the audio chunk generator
    stream = ffmpeg_stream(url, sampling_rate, chunk_duration_ms=chunk_duration_ms, pad_duration_ms=pad_duration_ms)

    leftover_text = ""
    for i, chunk in enumerate(stream):
        input_values = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits[0]

        if i > 0:
            logits = logits[output_pad_len : len(logits) - output_pad_len]
        else:  # don't count padding at the start of the clip
            logits = logits[: len(logits) - output_pad_len]

        predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        if processor.decode(predicted_ids).strip():
            leftover_ids = processor.tokenizer.encode(leftover_text)
            # Prepend the word (or word fragment) held back from the last chunk.
            text = processor.decode(leftover_ids + predicted_ids)
            # Hold back the last word in case it's only partially recognized.
            # Unlike `rsplit(" ", 1)` + unpacking, `rpartition` also copes with
            # text that has no space: nothing is emitted and all of it is held.
            text, _, leftover_text = text.rpartition(" ")
            yield text
        else:
            # Nothing recognized in this chunk: the held-back word is complete.
            yield leftover_text
            leftover_text = ""
    yield leftover_text

def _tag_url(url, chunk_duration_ms, pad_duration_ms):
    """Return ``url`` with a ``hash=<chunk>-<pad>`` query parameter appended.

    Any ``hash`` parameter added by a previous call (as ``?hash=`` or ``&hash=``)
    is dropped first, together with everything after it. ``?`` is used as the
    separator when the URL has no query string yet, so URLs such as
    ``https://youtu.be/<id>`` stay well-formed.
    """
    base = url.split("&hash=")[0].split("?hash=")[0]
    separator = "&" if "?" in base else "?"
    return f"{base}{separator}hash={chunk_duration_ms}-{pad_duration_ms}"


def main():
    """Render the input form, the video player, and the streaming subtitles.

    Values that must persist between Streamlit runs of this script (the
    transcription generator, the number of chunks consumed, and the subtitle
    lines) are kept in ``st.session_state``. Whenever the player reports a
    playback position beyond the audio transcribed so far, one more chunk is
    pulled from the generator.
    """
    state = st.session_state
    st.header("Video ASR Streamlit from Youtube Link")

    # Example talks on AI, cognitive science and neuroscience (behavioral and
    # medical health). Only "Timelapse AI" is used, as the default input.
    example_urls = {
        "Joscha Bach": "https://youtu.be/cC1HszE5Hcw?list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&t=8984",
        "Sam Harris": "https://youtu.be/4dC_nRYIDZU?list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&t=7809",
        "Sam Harris (2)": "https://youtu.be/Ui38ZzTymDY?list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L",
        "John Abramson": "https://www.youtube.com/watch?v=arrokG3wCdE&list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&index=3",
        "Elon Musk": "https://www.youtube.com/watch?v=DxREm3s1scA&list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&index=4",
        "Jeffrey Shainline": "https://www.youtube.com/watch?v=EwueqdgIvq4&list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&index=5",
        "Jeff Hawkins": "https://www.youtube.com/watch?v=Z1KwkpTUbkg&list=PLHgX2IExbFouJoqEr8JMF5MbZSbyC91-L&index=6",
        "Timelapse AI": "https://www.youtube.com/watch?v=63yr9dlI0cU&list=PLHgX2IExbFovQybyfltywXnqZi5YvaSS-",
    }

    with st.form(key="inputs_form"):
        state.youtube_url = st.text_input("YouTube URL", example_urls["Timelapse AI"])
        state.chunk_duration_ms = st.slider("Audio chunk duration (ms)", 2000, 10000, 3000, 100)
        state.pad_duration_ms = st.slider("Padding duration (ms)", 100, 5000, 1000, 100)
        submit_button = st.form_submit_button(label="Submit")

    if submit_button or "asr_stream" not in state:
        # a hack to update the video player on value changes: the URL changes
        # whenever the chunk/padding settings change
        state.youtube_url = _tag_url(state.youtube_url, state.chunk_duration_ms, state.pad_duration_ms)
        state.asr_stream = stream_text(state.youtube_url, state.chunk_duration_ms, state.pad_duration_ms)
        state.chunks_taken = 0
        state.lines = deque([], maxlen=100)  # limit to the last n lines of subs

    player = st_player(state.youtube_url, **player_options, key="youtube_player")

    if "asr_stream" in state and player.data and player.data["played"] < 1.0:
        # check how many seconds were played, and if more than processed - write the next text chunk
        processed_seconds = state.chunks_taken * (state.chunk_duration_ms / 1000)
        if processed_seconds < player.data["playedSeconds"]:
            # The generator is finite; once it is exhausted (stream ended or
            # failed) stop pulling instead of raising StopIteration.
            text = next(state.asr_stream, None)
            if text is not None:
                state.lines.append(text)
                state.chunks_taken += 1
    if "lines" in state:
        # print the lines of subs
        st.code("\n".join(state.lines))


if __name__ == "__main__":
    main()