Spaces:
Sleeping
Sleeping
File size: 1,408 Bytes
b7b7347 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from transformers import TextClassificationPipeline
import preprocess
import segment
class SponsorBlockClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline that accepts strings or transcript segments.

    Each input item is either a plain string, which is classified as-is, or
    a dict with the keys ``'video_id'``, ``'start'`` and ``'end'``, whose
    text is assembled from the words returned by the project's
    ``preprocess`` and ``segment`` modules.
    """

    def __init__(self, model, tokenizer):
        """Build the pipeline on the device that holds the model's parameters.

        Args:
            model: Model exposing ``parameters()``; the device of its first
                parameter decides which device the pipeline is given.
            tokenizer: Tokenizer forwarded unchanged to the base pipeline.
        """
        # The base pipeline is given an integer device id: the index of the
        # first parameter's device, or -1 when that device has no index
        # (presumably the CPU case -- confirm against the pipeline docs).
        param_device_index = next(model.parameters()).device.index
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            return_all_scores=True,
            truncation=True,
            device=-1 if param_device_index is None else param_device_index,
        )

    def preprocess(self, data, **tokenizer_kwargs):
        """Turn the input item(s) into tokenizer output.

        Args:
            data: A single item or a list of items. An item is either a
                ``str`` or a ``dict`` with ``'video_id'``, ``'start'`` and
                ``'end'`` keys (units of start/end are not visible here).
            **tokenizer_kwargs: Extra keyword arguments passed straight to
                the tokenizer call.

        Returns:
            Whatever ``self.tokenizer`` returns for the collected texts, with
            ``return_tensors`` set to ``self.framework``.

        Raises:
            ValueError: If an item is neither a ``dict`` nor a ``str``.
        """
        # A lone item is wrapped so that both shapes share one code path.
        items = data if isinstance(data, list) else [data]

        texts = []
        for item in items:
            # Strings are taken verbatim as the text to classify.
            if isinstance(item, str):
                texts.append(item)
                continue

            if not isinstance(item, dict):
                raise ValueError(f'Invalid input type: "{type(item)}"')

            # Dict items describe a transcript segment. Note that the name
            # `preprocess` below resolves to the imported module, not to this
            # method: method names are not in scope inside method bodies.
            transcript_words = preprocess.get_words(item['video_id'])
            words_in_segment = segment.extract_segment(
                transcript_words, item['start'], item['end'])
            joined_text = ' '.join(word['text'] for word in words_in_segment)
            texts.append(preprocess.clean_text(joined_text))

        # NOTE(review): nothing visible here enables padding, so a list of
        # texts with differing token lengths may not tokenize into a single
        # tensor -- confirm before relying on multi-item input.
        return self.tokenizer(
            texts, return_tensors=self.framework, **tokenizer_kwargs)
def main():
    """Script entry point; intentionally does nothing and returns None."""


# Run main() only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
|