import json
import logging
import os
import traceback
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import TimesformerForVideoClassification

from preprocessing import read_video

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_dir):
        # model_dir is part of the Inference Endpoints handler contract; the
        # checkpoint itself is loaded from the Hub by model id.
        self.model = TimesformerForVideoClassification.from_pretrained(
            "donghuna/timesformer-base-finetuned-k400-diving48",
            ignore_mismatched_sizes=True,
        )
        # Replace the classification head with a 48-way head for Diving48.
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[int]:
        """
        data args:
            inputs (:obj:`str`): base64-encoded video data
            date (:obj:`str`)
        Return:
            A :obj:`list` of predicted class indices that will be serialized and returned
        """
        inputs = data.get("inputs")
        # Decode the base64 payload into a frame tensor the model can consume.
        videos = read_video(inputs)
        with torch.no_grad():
            outputs = self.model(videos)
            logits = outputs.logits
        # Arg-max over the 48 Diving48 classes gives the predicted class index
        # for each clip in the batch.
        _, predicted = torch.max(logits, 1)
        return predicted.tolist()
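

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract). It assumes a local clip named `sample.mp4` exists and that
# `read_video` accepts a base64-encoded string, matching the handler above;
# adjust both assumptions to your environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import base64

    handler = EndpointHandler(model_dir=".")
    try:
        # `sample.mp4` is a hypothetical test clip used only for this sketch.
        with open("sample.mp4", "rb") as f:
            payload = base64.b64encode(f.read()).decode("utf-8")
        predictions = handler({"inputs": payload})
        logger.info("Predicted Diving48 class indices: %s", predictions)
    except Exception:
        logger.error("Smoke test failed:\n%s", traceback.format_exc())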