# PromoDetect / src/classify.py
# Shad0ws — "Upload 21 files" (commit b7b7347)
from transformers import TextClassificationPipeline
import preprocess
import segment
class SponsorBlockClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline that also accepts transcript segments.

    Each input item is either a plain string, which is classified as-is, or a
    dict with ``video_id``, ``start`` and ``end`` keys. For a dict, the
    transcript words between ``start`` and ``end`` are joined with spaces and
    cleaned before being classified.
    """

    def __init__(self, model, tokenizer):
        # Run on the device holding the model's parameters. A device with no
        # index (e.g. CPU) is passed as -1, the pipeline's CPU setting.
        # NOTE(review): accelerators whose torch.device has no index would
        # also fall back to -1 here — confirm that is acceptable.
        device_index = next(model.parameters()).device.index
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            return_all_scores=True,
            truncation=True,
            device=-1 if device_index is None else device_index,
        )

    def preprocess(self, data, **tokenizer_kwargs):
        """Convert ``data`` into tokenized model inputs.

        ``data`` may be a single item or a list of items. Each item is a
        ``str`` (used verbatim) or a ``dict`` describing a transcript segment.
        Extra keyword arguments are forwarded to the tokenizer.

        Raises:
            ValueError: if an item is neither a ``dict`` nor a ``str``.
        """
        items = data if isinstance(data, list) else [data]
        texts = [self._item_to_text(item) for item in items]
        # NOTE(review): no padding is requested, so a list of texts with
        # different token lengths may fail to stack into one tensor — confirm
        # whether multi-item lists are actually passed in.
        return self.tokenizer(
            texts, return_tensors=self.framework, **tokenizer_kwargs)

    @staticmethod
    def _item_to_text(item):
        """Return the text to classify for a single input item."""
        if isinstance(item, str):
            # A raw string is assumed to be exactly what the caller wants
            # classified.
            return item
        if isinstance(item, dict):
            # `preprocess` here is the module imported at file top, not the
            # method of the same name (class scope is not visible inside
            # method bodies).
            words = preprocess.get_words(item['video_id'])
            segment_words = segment.extract_segment(
                words, item['start'], item['end'])
            return preprocess.clean_text(
                ' '.join(word['text'] for word in segment_words))
        raise ValueError(f'Invalid input type: "{type(item)}"')
def main():
    """Script entry point; currently a placeholder that does nothing."""
# Invoke main() only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()