import preprocess
from dataclasses import dataclass, field


@dataclass
class SegmentationArguments:
    pause_threshold: float = field(default=2.5, metadata={
        'help': 'When the time between words is greater than pause threshold, force into a new segment'})
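
# Usage sketch (assumption, not confirmed by this file: the dataclass is parsed with
# transformers.HfArgumentParser, as is common for dataclass argument containers):
#   from transformers import HfArgumentParser
#   (segmentation_args,) = HfArgumentParser(SegmentationArguments).parse_args_into_dataclasses()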


def get_overlapping_chunks_of_tokens(tokens, size, overlap):
    """Yield chunks of up to `size` tokens, advancing the start by `size - overlap + 1` each time."""
    for i in range(0, len(tokens), size-overlap+1):
        yield tokens[i:i+size]
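
# Example (sketch): with size=4 and overlap=2 the step is size - overlap + 1 = 3,
# so consecutive chunks actually share one token (the `+ 1` makes the realised
# overlap one token smaller than `overlap`):
#   list(get_overlapping_chunks_of_tokens(list(range(10)), 4, 2))
#   -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]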


MIN_SAFETY_TOKENS = 8

# Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
# e.g. 512 -> 500, 768 -> 750
SAFETY_TOKENS_PERCENTAGE = 0.9765625


# TODO play around with this?
OVERLAP_TOKEN_PERCENTAGE = 0.5  # 0.25


def add_labels_to_words(words, sponsor_segments):
    """Copy each sponsor segment's category onto the words that fall inside it."""
    for sponsor_segment in sponsor_segments:
        for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
            w['category'] = sponsor_segment['category']

    return words
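
# Expected input shape (sketch): each element of sponsor_segments is a dict such as
#   {'start': 12.3, 'end': 45.6, 'category': 'sponsor'}
# and every word whose midpoint time falls inside [start, end] receives that category.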


def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
    """Generate token-limited segments from words, then label each word with its sponsor category."""
    segments = generate_segments(words, tokenizer, segmentation_args)

    labelled_segments = list(
        map(lambda x: add_labels_to_words(x, sponsor_segments), segments))

    return labelled_segments


def word_start(word):
    return word['start']


def word_end(word):
    return word.get('end', word['start'])


def generate_segments(words, tokenizer, segmentation_args):
    """Split words into segments at long pauses, then re-split any segment whose token
    count would exceed the safety-adjusted model maximum, keeping a token overlap
    between consecutive splits.
    """
    cleaned_words_list = []
    for w in words:
        w['cleaned'] = preprocess.clean_text(w['text'])
        cleaned_words_list.append(w['cleaned'])

    # Get lengths of tokenized words
    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
                                truncation=True, return_attention_mask=False, return_length=True).length

    first_pass_segments = []
    for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
        word['num_tokens'] = num_tokens

        # Add new segment
        if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
            first_pass_segments.append([word])

        else:  # Add to current segment
            first_pass_segments[-1].append(word)

    max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)

    buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size  # tokenizer.model_max_length

    # In second pass, we split those segments if too big
    second_pass_segments = []

    for segment in first_pass_segments:
        current_segment_num_tokens = 0
        current_segment = []
        after_split_segments = []
        for word in segment:
            new_seg = (current_segment_num_tokens +
                       word['num_tokens'] >= max_q_size)
            if new_seg:
                # Adding this token would make it have too many tokens
                # We save this batch and create new
                after_split_segments.append(current_segment)

            # Add tokens to current segment
            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if not new_seg:
                continue

            # Just created a new segment, so we remove until we only have buffer_size tokens
            last_index = 0
            while current_segment_num_tokens > buffer_size and current_segment:
                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
                last_index += 1

            current_segment = current_segment[last_index:]

        if current_segment:  # Add remaining segment
            after_split_segments.append(current_segment)

        # TODO if len(after_split_segments) > 1, a split occurred

        second_pass_segments.extend(after_split_segments)

    # Cleaning up, delete 'num_tokens' from each word
    for word in words:
        word.pop('num_tokens', None)

    return second_pass_segments
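
# Usage sketch (assumption: `words` is a list of dicts with 'text', 'start' and optional
# 'end' keys, and `tokenizer` is a Hugging Face tokenizer with a finite model_max_length):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('t5-small')
#   segments = generate_segments(words, tokenizer, SegmentationArguments())
# Each returned segment is a list of word dicts whose combined token count stays within
# the safety-adjusted maximum (SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length).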


def extract_segment(words, start, end, map_function=None):
    """Extracts all words with time in [start, end]"""
    if words is None:
        words = []

    a = max(binary_search_below(words, 0, len(words), start), 0)
    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))

    to_transform = map_function is not None and callable(map_function)

    return [
        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
    ]


def avg(*items):
    return sum(items)/len(items)


def binary_search_below(transcript, start_index, end_index, time):
    """Return the lowest index whose word midpoint time is >= `time`; end_index if none."""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time <= middle_time:
        return binary_search_below(transcript, start_index, middle_index, time)
    else:
        return binary_search_below(transcript, middle_index + 1, end_index, time)


def binary_search_above(transcript, start_index, end_index, time):
    """Return the highest index whose word midpoint time is <= `time`; start_index if none."""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index + 1) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time >= middle_time:
        return binary_search_above(transcript, middle_index, end_index, time)
    else:
        return binary_search_above(transcript, start_index, middle_index - 1, time)
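

if __name__ == '__main__':
    # Minimal self-check (sketch with made-up timings): extract_segment should return
    # the words whose midpoint times fall within [1.5, 3.5].
    example_words = [
        {'text': 'hello', 'start': 0.0, 'end': 0.5},
        {'text': 'this', 'start': 1.0, 'end': 1.4},
        {'text': 'video', 'start': 2.0, 'end': 2.6},
        {'text': 'is', 'start': 3.0, 'end': 3.2},
        {'text': 'sponsored', 'start': 4.0, 'end': 4.8},
    ]
    print(extract_segment(example_words, 1.5, 3.5, map_function=lambda w: w['text']))
    # Prints: ['video', 'is']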