from datetime import datetime
from typing import Dict, List, Any

from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import webvtt
import torch
import spacy
# spaCy English pipeline, loaded once when the module is imported so it is
# shared rather than re-loaded on every use.
nlp = spacy.load("en_core_web_sm")

# Tokenizer component taken from the pipeline above.
tokenizer = nlp.tokenizer

# NOTE(review): not referenced by the code visible here; presumably a cap on
# the number of tokens per classified text segment - confirm against callers.
token_limit = 200
class EndpointHandler:
    """Callable handler that owns a ``"roberta"``-type ``ClassificationModel``.

    NOTE(review): the class name suggests this is the entry point of a hosted
    inference endpoint - confirm against the deployment configuration.
    """

    def __init__(self, path="."):
        """Load the classification model.

        Args:
            path: Forwarded to ``ClassificationModel`` as its second
                positional argument (the model name or directory). Defaults
                to the current directory.
        """
        print("Loading models...")
        # Only request CUDA when torch reports that a device is available.
        cuda_available = torch.cuda.is_available()
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Handle one request.

        Args:
            data_file: A str pointing to the filename of a ``.vtt`` file.

        Returns:
            Currently always an empty list: ``data_file`` is not read and
            ``self.model`` is not used by this method.
        """
        return []