Joshua Lochner committed on
Commit bd6fd75 · 1 Parent(s): 366d154

Update pipeline.py

Files changed (1)
  1. pipeline.py +314 -66
pipeline.py CHANGED
@@ -1,78 +1,326 @@
 
 import json
 from typing import Any, Dict, List

-import tensorflow as tf
-from tensorflow import keras
-import base64
-import io
-import os
-import numpy as np
 from PIL import Image


 class PreTrainedPipeline():
     def __init__(self, path: str):
         # load the model
-        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))

-    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
-        # convert the image to a numpy array, then resize and normalize it to make the prediction
-        img = np.array(inputs)
-
-        im = tf.image.resize(img, (128, 128))
-        im = tf.cast(im, tf.float32) / 255.0
-        pred_mask = self.model.predict(im[tf.newaxis, ...])
-
-        # take the best-performing class for each pixel
-        # the output of argmax looks like this: [[1, 2, 0], ...]
-        pred_mask_arg = tf.argmax(pred_mask, axis=-1)
-
-        labels = []
-
-        # convert the prediction mask into binary masks for each class
-        binary_masks = {}
-        mask_codes = {}
-
-        # when we take tf.argmax() over pred_mask, it becomes a tensor object
-        # its shape becomes a TensorShape object, e.g. TensorShape([128])
-        # we need to get the shape, convert it to a list, and take the first element
-
-        rows = pred_mask_arg[0][1].get_shape().as_list()[0]
-        cols = pred_mask_arg[0][2].get_shape().as_list()[0]
-
-        for cls in range(pred_mask.shape[-1]):
-            # create a mask for each class
-            binary_masks[f"mask_{cls}"] = np.zeros(shape=(pred_mask.shape[1], pred_mask.shape[2]))
-
-            for row in range(rows):
-                for col in range(cols):
-                    if pred_mask_arg[0][row][col] == cls:
-                        binary_masks[f"mask_{cls}"][row][col] = 1
-                    else:
-                        binary_masks[f"mask_{cls}"][row][col] = 0
-
-            mask = binary_masks[f"mask_{cls}"]
-            mask *= 255
-            img = Image.fromarray(mask.astype(np.int8), mode="L")
-
-            # we need to make it readable for the widget
-            with io.BytesIO() as out:
-                img.save(out, format="PNG")
-                png_string = out.getvalue()
-                mask = base64.b64encode(png_string).decode("utf-8")
-
-            mask_codes[f"mask_{cls}"] = mask
-
-            # the widget needs the format below: for each class we return a label and a mask string
-            labels.append({
-                "label": f"LABEL_{cls}",
-                "mask": mask_codes[f"mask_{cls}"],
-                "score": 1.0,
-            })
-        return labels
+import youtube_transcript_api2
 import json
+import re
+import requests
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    TextClassificationPipeline,
+)
 from typing import Any, Dict, List

 from PIL import Image

+CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']

+PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
+
+NUM_DECIMALS = 3
+
+# https://www.fincher.org/Utilities/CountryLanguageList.shtml
+# https://lingohub.com/developers/supported-locales/language-designators-with-regions
+LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
+                            'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
+                            'en']
+
+
+def parse_transcript_json(json_data, granularity):
+    assert json_data['wireMagic'] == 'pb3'
+
+    assert granularity in ('word', 'chunk')
+
+    # TODO remove bracketed words?
+    # (kiss smacks)
+    # (upbeat music)
+    # [text goes here]
+
+    # Some manual transcripts aren't that well formatted... but they do have punctuation
+    # https://www.youtube.com/watch?v=LR9FtWVjk2c
+
+    parsed_transcript = []
+
+    events = json_data['events']
+
+    for event_index, event in enumerate(events):
+        segments = event.get('segs')
+        if not segments:
+            continue
+
+        # This value is known (when the phrase appears on screen)
+        start_ms = event['tStartMs']
+        total_characters = 0
+
+        new_segments = []
+        for seg in segments:
+            # Replace \n, \t, etc. with a space
+            text = ' '.join(seg['utf8'].split())
+
+            # Remove zero-width characters and strip leading/trailing whitespace
+            text = text.replace('\u200b', '').replace('\u200c', '').replace(
+                '\u200d', '').replace('\ufeff', '').strip()
+
+            # Alternatively,
+            # text = text.encode('ascii', 'ignore').decode()
+
+            # Needed for auto-generated transcripts
+            text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
+
+            if not text:
+                continue
+
+            offset_ms = seg.get('tOffsetMs', 0)
+
+            new_segments.append({
+                'text': text,
+                'start': round((start_ms + offset_ms) / 1000, NUM_DECIMALS)
+            })
+
+            total_characters += len(text)
+
+        if not new_segments:
+            continue
+
+        if event_index < len(events) - 1:
+            next_start_ms = events[event_index + 1]['tStartMs']
+            total_event_duration_ms = min(
+                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
+        else:
+            total_event_duration_ms = event.get('dDurationMs', 0)
+
+        # Ensure duration is non-negative
+        total_event_duration_ms = max(total_event_duration_ms, 0)
+
+        avg_seconds_per_character = (
+            total_event_duration_ms / total_characters) / 1000
+
+        num_char_count = 0
+        for seg_index, seg in enumerate(new_segments):
+            num_char_count += len(seg['text'])
+
+            # Estimate segment end from the cumulative character count
+            seg_end = seg['start'] + \
+                (num_char_count * avg_seconds_per_character)
+
+            if seg_index < len(new_segments) - 1:
+                # Do not allow a segment to end later than the next one starts
+                seg_end = min(seg_end, new_segments[seg_index + 1]['start'])
+
+            seg['end'] = round(seg_end, NUM_DECIMALS)
+            parsed_transcript.append(seg)
+
+    final_parsed_transcript = []
+    for i in range(len(parsed_transcript)):
+
+        word_level = granularity == 'word'
+        if word_level:
+            split_text = parsed_transcript[i]['text'].split()
+        elif granularity == 'chunk':
+            # Split on whitespace after punctuation (., !, ?, ,, ;, -)
+            split_text = re.split(
+                r'(?<=[.!?,;-])\s+', parsed_transcript[i]['text'])
+            if len(split_text) == 1:
+                split_on_whitespace = parsed_transcript[i]['text'].split()
+
+                if len(split_on_whitespace) >= 8:  # Too many words
+                    # Split on whitespace instead of punctuation
+                    split_text = split_on_whitespace
+                else:
+                    word_level = True
+        else:
+            raise ValueError('Unknown granularity')
+
+        segment_end = parsed_transcript[i]['end']
+        if i < len(parsed_transcript) - 1:
+            segment_end = min(segment_end, parsed_transcript[i + 1]['start'])
+
+        segment_duration = segment_end - parsed_transcript[i]['start']
+
+        num_chars_in_text = sum(map(len, split_text))
+
+        num_char_count = 0
+        current_offset = 0
+        for s in split_text:
+            num_char_count += len(s)
+
+            next_offset = (num_char_count / num_chars_in_text) * segment_duration
+
+            word_start = round(
+                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
+            word_end = round(
+                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
+
+            # Make the reasonable assumption that a single word lasts at most 1.5 seconds
+            final_parsed_transcript.append({
+                'text': s,
+                'start': word_start,
+                'end': min(word_end, word_start + 1.5) if word_level else word_end
+            })
+            current_offset = next_offset
+
+    return final_parsed_transcript
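
For context, a minimal sketch of the json3 payload shape that parse_transcript_json consumes, inferred from the fields the parser reads (wireMagic, events, tStartMs, dDurationMs, segs, utf8, tOffsetMs); the sample values are invented and not part of the commit:

    sample_json3 = {
        'wireMagic': 'pb3',
        'events': [
            {
                'tStartMs': 0,         # when the phrase appears on screen
                'dDurationMs': 1830,   # total event duration (assumed meaning)
                'segs': [
                    {'utf8': 'hey'},                       # first segment has no offset
                    {'utf8': ' there', 'tOffsetMs': 450},  # offset relative to tStartMs
                ],
            },
        ],
    }

    # Word-level timings are interpolated from character counts
    words = parse_transcript_json(sample_json3, granularity='word')
    # -> [{'text': 'hey', 'start': 0.0, 'end': 0.45}, {'text': 'there', 'start': 0.45, 'end': ...}]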
+
+
+def list_transcripts(video_id):
+    try:
+        return youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
+    except json.decoder.JSONDecodeError:
+        return None
+
+
+WORDS_TO_REMOVE = [
+    '[Music]',
+    '[Applause]',
+    '[Laughter]',
+]
+
+
+def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
+    """Get the parsed transcript of a video, optionally falling back to the
+    other transcript type. Returns None or [] if no transcript is available.
+    """
+
+    raw_transcript_json = None
+    try:
+        transcript_list = list_transcripts(video_id)
+
+        if transcript_list is not None:
+            if transcript_type == 'manual':
+                ts = transcript_list.find_manually_created_transcript(
+                    LANGUAGE_PREFERENCE_LIST)
+            else:
+                ts = transcript_list.find_generated_transcript(
+                    LANGUAGE_PREFERENCE_LIST)
+            raw_transcript = ts._http_client.get(
+                f'{ts._url}&fmt=json3').content
+            if raw_transcript:
+                raw_transcript_json = json.loads(raw_transcript)
+
+    except (youtube_transcript_api2.TooManyRequests, youtube_transcript_api2.YouTubeRequestFailed):
+        raise  # Cannot recover from these errors; do not mark as an empty transcript
+
+    except requests.exceptions.RequestException:  # Can recover, so retry
+        return get_words(video_id, transcript_type, fallback,
+                         filter_words_to_remove, granularity)
+
+    except youtube_transcript_api2.CouldNotRetrieveTranscript:  # Retrying won't help
+        pass  # Mark as an empty transcript
+
+    except json.decoder.JSONDecodeError:  # Malformed response, retry
+        return get_words(video_id, transcript_type, fallback,
+                         filter_words_to_remove, granularity)
+
+    if not raw_transcript_json and fallback is not None:
+        return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
+
+    if raw_transcript_json:
+        processed_transcript = parse_transcript_json(
+            raw_transcript_json, granularity)
+        if filter_words_to_remove:
+            processed_transcript = list(
+                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
+    else:
+        processed_transcript = raw_transcript_json  # Either None or []
+
+    return processed_transcript
+
+
+def word_start(word):
+    return word['start']
+
+
+def word_end(word):
+    return word.get('end', word['start'])
+
+
+def extract_segment(words, start, end, map_function=None):
+    """Extracts all words whose midpoint time lies in [start, end]"""
+
+    a = max(binary_search_below(words, 0, len(words), start), 0)
+    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
+
+    to_transform = map_function is not None and callable(map_function)
+
+    return [
+        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
+    ]
+
+
+def avg(*items):
+    return sum(items) / len(items)
+
+
+# Find the first index whose word midpoint is at least `time`
+def binary_search_below(transcript, start_index, end_index, time):
+    if start_index >= end_index:
+        return end_index
+
+    middle_index = (start_index + end_index) // 2
+    middle = transcript[middle_index]
+    middle_time = avg(word_start(middle), word_end(middle))
+
+    if time <= middle_time:
+        return binary_search_below(transcript, start_index, middle_index, time)
+    else:
+        return binary_search_below(transcript, middle_index + 1, end_index, time)
+
+
+# Find the last index whose word midpoint is at most `time`
+def binary_search_above(transcript, start_index, end_index, time):
+    if start_index >= end_index:
+        return end_index
+
+    middle_index = (start_index + end_index + 1) // 2
+    middle = transcript[middle_index]
+    middle_time = avg(word_start(middle), word_end(middle))
+
+    if time >= middle_time:
+        return binary_search_above(transcript, middle_index, end_index, time)
+    else:
+        return binary_search_above(transcript, start_index, middle_index - 1, time)
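
To make the segment-extraction helpers concrete, a small hypothetical example (the word list below is invented): extract_segment includes a word when the midpoint avg(start, end) falls inside the requested window:

    words = [
        {'text': 'hey', 'start': 0.0, 'end': 0.4},
        {'text': 'this', 'start': 0.5, 'end': 0.7},
        {'text': 'video', 'start': 0.8, 'end': 1.2},
        {'text': 'is', 'start': 1.3, 'end': 1.4},
        {'text': 'sponsored', 'start': 1.5, 'end': 2.1},
    ]

    segment = extract_segment(words, 0.5, 1.4)
    print(' '.join(w['text'] for w in segment))  # this video is

    # map_function projects each extracted word, e.g. keeping only the text
    texts = extract_segment(words, 0.5, 1.4, map_function=lambda w: w['text'])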
+
+
+class SponsorBlockClassificationPipeline(TextClassificationPipeline):
+    def __init__(self, model, tokenizer):
+        super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)
+
+    def preprocess(self, video, **tokenizer_kwargs):
+        # Fetch the transcript and keep only the words inside the requested segment
+        words = get_words(video['video_id'])
+        segment_words = extract_segment(words, video['start'], video['end'])
+        text = ' '.join(x['text'] for x in segment_words)
+
+        model_inputs = self.tokenizer(
+            text, return_tensors=self.framework, **tokenizer_kwargs)
+        return {'video': video, 'model_inputs': model_inputs}
+
+    def _forward(self, data):
+        model_outputs = self.model(**data['model_inputs'])
+        return {'video': data['video'], 'model_outputs': model_outputs}
+
+    def postprocess(self, data, function_to_apply=None, return_all_scores=False):
+        model_outputs = data['model_outputs']
+
+        results = super().postprocess(model_outputs, function_to_apply, return_all_scores)
+
+        for result in results:
+            # Map the predicted class to a human-readable category;
+            # handle both integer labels and 'LABEL_n'-style strings
+            label = result['label']
+            if isinstance(label, str) and label.startswith('LABEL_'):
+                label = int(label.split('_')[-1])
+            result['label_text'] = CATEGORIES[label]
+
+        return results  # {**data['video'], 'result': results}

 class PreTrainedPipeline():
     def __init__(self, path: str):
         # load the model
+        self.model = AutoModelForSequenceClassification.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pipeline = SponsorBlockClassificationPipeline(
+            model=self.model, tokenizer=self.tokenizer)
 
315
+ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
316
+ json_data = json.loads(inputs)
317
+ return self.pipeline(json_data)
318
 
319
+ def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
320
+ json_data = [{
321
+ 'video_id': 'pqh4LfPeCYs',
322
+ 'start': 835.933,
323
+ 'end': 927.581,
324
+ 'category': 'sponsor'
325
+ }]
326
+ return self.pipeline(json_data)
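
As a rough end-to-end usage sketch (the checkpoint directory '.' and the calling convention shown here are assumptions, not part of this commit; the segment values come from the example in the code above):

    import json
    from pipeline import PreTrainedPipeline

    pipe = PreTrainedPipeline('.')  # assumed: directory with the model and tokenizer files

    results = pipe(json.dumps([{
        'video_id': 'pqh4LfPeCYs',
        'start': 835.933,
        'end': 927.581,
        'category': 'sponsor',
    }]))
    # Each result carries per-class scores plus a 'label_text' drawn from CATEGORIES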