import preprocess
from dataclasses import dataclass, field


@dataclass
class SegmentationArguments:
    pause_threshold: float = field(default=2.5, metadata={
        'help': 'When the time between words is greater than the pause threshold, force the start of a new segment'})


def get_overlapping_chunks_of_tokens(tokens, size, overlap):
    for i in range(0, len(tokens), size - overlap + 1):
        yield tokens[i:i + size]
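
# Note (added for clarity): with the stride of size - overlap + 1 above,
# consecutive chunks share overlap - 1 tokens of context.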

# Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
MIN_SAFETY_TOKENS = 8
SAFETY_TOKENS_PERCENTAGE = 0.9765625
# e.g. 512 -> 500, 768 -> 750

# TODO play around with this?
OVERLAP_TOKEN_PERCENTAGE = 0.5  # 0.25
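# For example, with a 512-token model these give a chunk budget of
# round(0.9765625 * 512) = 500 tokens and an overlap buffer of
# 0.5 * 500 = 250 tokens (see generate_segments below).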


def add_labels_to_words(words, sponsor_segments):
    for sponsor_segment in sponsor_segments:
        for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
            w['category'] = sponsor_segment['category']

    return words


def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
    segments = generate_segments(words, tokenizer, segmentation_args)

    labelled_segments = [
        add_labels_to_words(segment, sponsor_segments) for segment in segments]

    return labelled_segments


def word_start(word):
    return word['start']


def word_end(word):
    # Fall back to the start time when a word has no end timestamp
    return word.get('end', word['start'])


def generate_segments(words, tokenizer, segmentation_args):
    cleaned_words_list = []
    for w in words:
        w['cleaned'] = preprocess.clean_text(w['text'])
        cleaned_words_list.append(w['cleaned'])

    # Get lengths of tokenized words
    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
                                truncation=True, return_attention_mask=False,
                                return_length=True).length

    # First pass: start a new segment whenever the pause between words is long enough
    first_pass_segments = []
    for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
        word['num_tokens'] = num_tokens

        # Add new segment
        if index == 0 or word_start(words[index]) - word_end(words[index - 1]) >= segmentation_args.pause_threshold:
            first_pass_segments.append([word])
        else:  # Add to current segment
            first_pass_segments[-1].append(word)

    max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)

    buffer_size = OVERLAP_TOKEN_PERCENTAGE * max_q_size  # tokenizer.model_max_length
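
    # Note (added for clarity): the split logic below carries over at most
    # buffer_size tokens (including the word that triggered the split) from the
    # end of one chunk into the next, so consecutive chunks overlap by up to
    # that much context.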

    # In second pass, we split those segments if too big
    second_pass_segments = []
    for segment in first_pass_segments:
        current_segment_num_tokens = 0
        current_segment = []
        after_split_segments = []
        for word in segment:
            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
            if new_seg:
                # Adding this word would put the segment at or over the token
                # budget, so we save the current batch and start a new one
                after_split_segments.append(current_segment)

            # Add tokens to current segment
            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if not new_seg:
                continue

            # Just created a new segment, so drop words from the front until at
            # most buffer_size tokens remain (these become the overlap)
            last_index = 0
            while current_segment_num_tokens > buffer_size and current_segment:
                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
                last_index += 1

            current_segment = current_segment[last_index:]

        if current_segment:  # Add remaining segment
            after_split_segments.append(current_segment)

        # TODO if len(after_split_segments) > 1, a split occurred
        second_pass_segments.extend(after_split_segments)

    # Clean up: remove the temporary 'num_tokens' key from each word
    for word in words:
        word.pop('num_tokens', None)

    return second_pass_segments


def extract_segment(words, start, end, map_function=None):
    """Extracts all words whose midpoint time falls within [start, end]"""
    if words is None:
        words = []

    a = max(binary_search_below(words, 0, len(words), start), 0)
    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))

    to_transform = map_function is not None and callable(map_function)

    return [
        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
    ]


def avg(*items):
    return sum(items) / len(items)


def binary_search_below(transcript, start_index, end_index, time):
    """Returns the index of the first word whose midpoint time is >= time
    (or the initial end_index if no word qualifies)."""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time <= middle_time:
        return binary_search_below(transcript, start_index, middle_index, time)
    else:
        return binary_search_below(transcript, middle_index + 1, end_index, time)


def binary_search_above(transcript, start_index, end_index, time):
    """Returns the index of the last word whose midpoint time is <= time
    (or the initial start_index if no word qualifies)."""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index + 1) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time >= middle_time:
        return binary_search_above(transcript, middle_index, end_index, time)
    else:
        return binary_search_above(transcript, start_index, middle_index - 1, time)
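

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): shows how the
# functions above fit together on a tiny hand-made transcript. The 't5-small'
# checkpoint and the example words/sponsor_segments below are assumptions
# chosen for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer  # assumed dependency for this demo

    tokenizer = AutoTokenizer.from_pretrained('t5-small')  # assumed checkpoint
    segmentation_args = SegmentationArguments()

    words = [
        {'text': 'Hello', 'start': 0.0, 'end': 0.4},
        {'text': 'world', 'start': 0.5, 'end': 0.9},
        # 4.1 s gap >= pause_threshold, so a new segment starts here
        {'text': 'sponsored', 'start': 5.0, 'end': 5.6},
    ]
    sponsor_segments = [{'start': 4.5, 'end': 6.0, 'category': 'sponsor'}]

    for segment in generate_labelled_segments(
            words, tokenizer, segmentation_args, sponsor_segments):
        print([(w['text'], w.get('category')) for w in segment])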