|
from simpletransformers.classification import ClassificationModel, ClassificationArgs |
|
from typing import Dict, List, Any |
|
import pandas as pd |
|
import webvtt |
|
from datetime import datetime |
|
import torch |
|
import spacy |
|
|
|
# NOTE(review): not referenced anywhere in the code shown; presumably meant as
# a per-utterance token cap — confirm before relying on it.
token_limit = 200

# Load the small English spaCy pipeline once, at import time, and expose its
# tokenizer as a module-level handle.
nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.tokenizer
|
|
|
class EndpointHandler:
    """Classify the caption utterances of a WebVTT file with a RoBERTa model."""

    def __init__(self, path="."):
        """Load the RoBERTa classifier stored at ``path``.

        Args:
            path: Model location handed to ``ClassificationModel``.
                Defaults to the current directory.
        """
        print("Loading models...")

        # Use the GPU when PyTorch can see one, otherwise fall back to CPU.
        cuda_available = torch.cuda.is_available()

        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    @staticmethod
    def _read_utterances(data_file: str) -> List[str]:
        """Return the non-empty caption texts of a .vtt file, in file order."""
        # A caption may span several lines; collapse all whitespace so each
        # caption becomes a single one-line utterance.
        texts = (" ".join(caption.text.split()) for caption in webvtt.read(data_file))
        return [text for text in texts if text]

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Classify every caption utterance in a .vtt file.

        Args:
            data_file: Path to a WebVTT (``.vtt``) file.

        Returns:
            The ``predictions`` returned by ``ClassificationModel.predict`` for
            the file's non-empty captions, in file order; ``[]`` when the file
            has no non-empty captions.
        """
        utterances_list = self._read_utterances(data_file)

        # Nothing to classify: skip the model call rather than predicting on
        # an empty batch.
        if not utterances_list:
            return []

        # NOTE(review): the annotation says List[Dict[str, Any]], but this
        # returns the raw ``predictions`` from ``predict`` unchanged — confirm
        # the intended output shape with callers.
        predictions, _raw_outputs = self.model.predict(utterances_list)
        return predictions
|
|