"""
https://huggingface.co/tomiwa1a/video-search
"""
from typing import Dict

from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
import pytube
import time


class EndpointHandler():
    # model names and device configuration
    WHISPER_MODEL_NAME = "tiny.en"
    SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
    QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
    SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_number = 0 if torch.cuda.is_available() else -1

    def __init__(self, path=""):

        device = "cuda" if torch.cuda.is_available() else "cpu"
        device_number = 0 if torch.cuda.is_available() else -1
        print(f'whisper and question_answer_model will use: {device}')
        print(f'whisper and question_answer_model will use device_number: {device_number}')

        t0 = time.time()
        # .to() expects a torch device ("cuda"/"cpu"), not the pipeline-style device index
        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading whisper_model in {total} seconds')

        t0 = time.time()
        self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading sentence_transformer_model in {total} seconds')

        t0 = time.time()
        self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device_number)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading summarizer in {total} seconds')

        self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
        t0 = time.time()
        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.QUESTION_ANSWER_MODEL_NAME).to(device)
        t1 = time.time()
        total = t1 - t0
        print(f'Finished loading question_answer_model in {total} seconds')

    def __call__(self, data: Dict[str, str]) -> Dict:
        """
        Args:
            data (:obj:):
                includes the URL to video for transcription
        Return:
            A :obj:`dict`:. transcribed dict
        """
        # process input
        print('data', data)

        if "inputs" not in data:
            raise Exception(f"data is missing 'inputs' key which  EndpointHandler expects. Received: {data}"
                            f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
        video_url = data.pop("video_url", None)
        query = data.pop("query", None)
        long_form_answer = data.pop("long_form_answer", None)
        summarize = data.pop("summarize", False)
        encoded_segments = {}
        if video_url:
            video_with_transcript = self.transcribe_video(video_url)
            video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
            encode_transcript = data.pop("encode_transcript", True)
            if encode_transcript:
                encoded_segments = self.combine_transcripts(video_with_transcript)
                encoded_segments = {
                    "encoded_segments": self.encode_sentences(encoded_segments)
                }
            return {
                **video_with_transcript,
                **encoded_segments
            }
        elif summarize:
            summary = self.summarize_video(data["segments"])
            return {"summary": summary}
        elif query:
            if long_form_answer:
                context = data.pop("context", None)
                answer = self.generate_answer(query, context)
                response = {
                    "answer": answer
                }

                return response
            else:
                query = [{"text": query, "id": ""}] if isinstance(query, str) else query
                encoded_segments = self.encode_sentences(query)

                response = {
                    "encoded_segments": encoded_segments
                }

                return response

        else:
            return {
                "error": "'video_url' or 'query' must be provided"
            }

    def transcribe_video(self, video_url):
        decode_options = {
            # Set language to None to support multilingual input, but transcription
            # takes longer while Whisper detects the language. (Observed by running
            # in verbose mode and timing the language-detection step.)
            "language": "en",
            "verbose": True
        }
        yt = pytube.YouTube(video_url)
        video_info = {
            'id': yt.video_id,
            'thumbnail': yt.thumbnail_url,
            'title': yt.title,
            'views': yt.views,
            'length': yt.length,
            # Although this might seem redundant since we already have the id,
            # it lets the link to the video be opened in one click from the API response
            'url': f"https://www.youtube.com/watch?v={yt.video_id}"
        }
        stream = yt.streams.filter(only_audio=True)[0]
        path_to_audio = f"{yt.video_id}.mp3"
        stream.download(filename=path_to_audio)
        t0 = time.time()
        transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
        t1 = time.time()
        for segment in transcript['segments']:
            # Remove the tokens array, it makes the response too verbose
            segment.pop('tokens', None)

        total = t1 - t0
        print(f'Finished transcription in {total} seconds')

        # postprocess the prediction
        return {"transcript": transcript, 'video': video_info}

    def encode_sentences(self, transcripts, batch_size=64):
        """
        Encoding all of our segments at once (or holding them all in memory) would
        require too much compute or memory, so we encode them in batches.
        :param transcripts: list of segment dicts, each containing a 'text' key
        :param batch_size: number of segments to encode per batch
        :return: the segments, each with an added 'vectors' key holding its embedding
        """
        # loop through the segments in batches of batch_size
        all_batches = []
        for i in tqdm(range(0, len(transcripts), batch_size)):
            # find end position of batch (for when we hit end of data)
            i_end = min(len(transcripts), i + batch_size)
            # extract the metadata like text, start/end positions, etc
            batch_meta = [{
                **row
            } for row in transcripts[i:i_end]]
            # extract only text to be encoded by embedding model
            batch_text = [
                row['text'] for row in batch_meta
            ]
            # create the embedding vectors
            batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()

            batch_details = [
                {
                    **batch_meta[x],
                    'vectors': batch_vectors[x]
                } for x in range(0, len(batch_meta))
            ]
            all_batches.extend(batch_details)

        return all_batches

    def summarize_video(self, segments):
        for index, segment in enumerate(segments):
            segment['summary'] = self.summarizer(segment['text'])
            segment['summary'] = segment['summary'][0]['summary_text']
            print('index', index)
            print('length', segment['length'])
            print('text', segment['text'])
            print('summary', segment['summary'])

        return segments

    def generate_answer(self, query, documents):

        # concatenate question and support documents into BART input
        conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True,
                                                     return_tensors="pt")

        generated_answers_encoded = self.question_answer_model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            min_length=64,
            max_length=256,
            do_sample=False,
            early_stopping=True,
            num_beams=8,
            temperature=1.0,
            top_k=None,
            top_p=None,
            eos_token_id=self.question_answer_tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            num_return_sequences=1)
        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                             clean_up_tokenization_spaces=True)
        return answer

    @staticmethod
    def combine_transcripts(video, window=6, stride=3):
        """

        :param video:
        :param window: number of sentences to combine
        :param stride: number of sentences to 'stride' over, used to create overlap
        :return:
        """
        new_transcript_segments = []

        video_info = video['video']
        transcript_segments = video['transcript']['segments']
        for i in tqdm(range(0, len(transcript_segments), stride)):
            i_end = min(len(transcript_segments), i + window)
            text = ' '.join(transcript['text']
                            for transcript in
                            transcript_segments[i:i_end])
            # TODO: Should int (float to seconds) conversion happen at the API level?
            start = int(transcript_segments[i]['start'])
            end = int(transcript_segments[i]['end'])
            new_transcript_segments.append({
                **video_info,
                **{
                    'start': start,
                    'end': end,
                    'title': video_info['title'],
                    'text': text,
                    'id': f"{video_info['id']}-t{start}",
                    'url': f"https://youtu.be/{video_info['id']}?t={start}",
                    'video_id': video_info['id'],
                }
            })
        return new_transcript_segments
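

# A minimal local usage sketch, not part of the original handler. It assumes you run
# this file directly outside the Inference Endpoints runtime; the YouTube URL below
# is a placeholder to replace with a real video.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Transcribe a video and encode its transcript segments
    result = handler({
        "inputs": "",
        "video_url": "https://www.youtube.com/watch?v=<video_id>",
        "encode_transcript": True,
    })
    print(result["video"]["title"], len(result["encoded_segments"]))

    # Encode a free-text query in the same embedding space as the segments
    query_result = handler({"inputs": "", "query": "What is this video about?"})
    print(query_result["encoded_segments"][0]["vectors"][:5])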