import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback
import os
from typing import Dict, List, Any

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """HuggingFace Inference Endpoint handler for Diving48 video classification.

    Loads a TimeSformer checkpoint and replaces its classification head with a
    fresh 48-way linear layer (one output per Diving48 action class).
    """

    # Number of Diving48 action classes predicted by the classifier head.
    NUM_CLASSES = 48

    def __init__(self, model_dir):
        # NOTE(review): `model_dir` is ignored and the checkpoint id is
        # hard-coded below. Loading from `model_dir` is the conventional
        # endpoint behavior — confirm this is intentional.
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True,
        )
        # The checkpoint's head size may not match; swap in a 48-class head.
        self.model.classifier = torch.nn.Linear(
            self.model.classifier.in_features, self.NUM_CLASSES
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        Handle one inference request.

        Args:
            data: request payload; expects a base64-encoded video under the
                "inputs" key (a "date" key may also be present but is unused).

        Return:
            Currently a debug stub: the first 10 characters of the raw input
            string, or an error dict when the input is missing or undecodable.
        """
        inputs = data.get("inputs")
        if not inputs:
            # Guard: slicing None below would raise TypeError; return a
            # structured error instead of crashing the endpoint.
            return {"error": "No video input provided"}
        try:
            # Decoded frames are currently discarded — this only smoke-tests
            # the preprocessing step.
            read_video(inputs)
        except Exception:
            logger.exception("Failed to read video input")
            return {"error": "Failed to decode video input"}
        # TODO: run self.model on the decoded frames and return predictions.
        return inputs[0:10]