# connecting / handler.py~
# Ashlee Kupor
# Add model
# 5cefadd
# raw
# history blame contribute delete
# 816 Bytes
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from typing import Dict, List, Any
import pandas as pd
import webvtt
from datetime import datetime
import torch
import spacy
# Load the spaCy pipeline "en_core_web_sm" once, when this module is imported.
nlp = spacy.load("en_core_web_sm")
# Module-level handle on that pipeline's tokenizer.
tokenizer = nlp.tokenizer
# NOTE(review): not referenced anywhere in the visible code; presumably a cap
# on the number of tokens per utterance -- confirm before relying on it.
token_limit = 200
class EndpointHandler:
    """Callable wrapper around a RoBERTa ``ClassificationModel``.

    The classifier is loaded once in ``__init__`` and queried each time the
    instance is called.
    """

    def __init__(self, path="."):
        """Load the classifier found at *path*.

        Args:
            path: Model location passed straight to ``ClassificationModel``;
                defaults to the current directory.
        """
        print("Loading models...")
        # CUDA is requested only when torch reports that it is available.
        self.model = ClassificationModel(
            "roberta",
            path,
            use_cuda=torch.cuda.is_available(),
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Run the classifier and return its predictions.

        Args:
            data_file: Name of a ``.vtt`` file.

        Returns:
            The first element of the pair returned by
            ``self.model.predict``.
        """
        # NOTE(review): ``data_file`` is never opened here, so the batch handed
        # to the model is always empty -- presumably the .vtt parsing step has
        # not been written yet; confirm the intended behaviour.
        batch = []
        labels, _raw_scores = self.model.predict(batch)
        return labels