donghuna's picture
Update handler.py
2990f6d verified
raw
history blame
1.63 kB
import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback
import os
from typing import Dict, List, Any
# Logging setup: create this module's logger, then make sure the root logger
# emits INFO-level (and above) records.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class EndpointHandler:
    """Callable handler that classifies a video with a TimeSformer model.

    An instance is built once with a model directory and is then called with
    a request payload dictionary; each call returns the predicted class
    indices for the submitted video(s).
    """

    # Hub checkpoint that is loaded when the handler is constructed.
    MODEL_ID = 'donghuna/timesformer-base-finetuned-k400-diving48'
    # Number of output classes the classification head must expose.
    NUM_CLASSES = 48

    def __init__(self, model_dir):
        """Load the checkpoint and switch the model to evaluation mode.

        Args:
            model_dir: Accepted to keep the constructor signature expected by
                the caller; it is not used, the weights always come from
                ``MODEL_ID``.
        """
        self.model = TimesformerForVideoClassification.from_pretrained(
            self.MODEL_ID,
            ignore_mismatched_sizes=True,
        )

        # Only swap the classification head when its width is wrong.
        # Replacing it unconditionally would throw away whatever head weights
        # the checkpoint provided and substitute randomly initialised ones.
        head = self.model.classifier
        if getattr(head, "out_features", None) != self.NUM_CLASSES:
            logging.getLogger(__name__).warning(
                "Classifier head does not have %d outputs; replacing it with "
                "a randomly initialised layer.",
                self.NUM_CLASSES,
            )
            self.model.classifier = torch.nn.Linear(
                head.in_features, self.NUM_CLASSES
            )

        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[int]:
        """Run inference on the video carried by a request payload.

        Args:
            data: Request payload. Only the ``"inputs"`` key is read; per the
                original documentation it holds base64-encoded video data,
                which is handed to ``read_video`` for decoding.
                NOTE(review): the exact format accepted by ``read_video`` is
                defined in ``preprocessing`` - confirm there.

        Returns:
            A list with the index of the highest-scoring class for each row
            of the model's logits.

        Raises:
            ValueError: If ``"inputs"`` is missing, ``None`` or empty.
        """
        inputs = data.get("inputs")
        if inputs is None or (isinstance(inputs, (str, bytes)) and not inputs):
            # Fail with a clear message instead of an opaque error from the
            # preprocessing step.
            raise ValueError("No video input provided")

        # Presumably a tensor laid out as the model expects - TODO confirm
        # against preprocessing.read_video.
        videos = read_video(inputs)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            logits = self.model(videos).logits

        # Index of the maximum logit along the class dimension.
        return logits.argmax(dim=1).tolist()