File size: 1,781 Bytes
5b4786e
 
 
8b2bb9b
54947ac
5d6e7ad
d232895
49cd757
32e93b2
54947ac
 
 
 
5b4786e
 
a5cd004
1f30fb0
 
 
 
c3b75f9
5b4786e
 
b9431dc
32e93b2
 
293340f
32e93b2
293340f
 
32e93b2
 
 
07d81c1
8efce6f
2990f6d
 
8efce6f
2990f6d
8efce6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9431dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback
import os
from typing import Dict, List, Any

# Logging setup: configure the root logger at INFO level and create the
# module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Callable inference handler that classifies videos with TimeSformer.

    The handler loads a ``TimesformerForVideoClassification`` checkpoint once
    at construction time and, on each call, returns the highest-probability
    class indices for every video in the request.
    """

    # Hub checkpoint loaded by ``__init__``.
    MODEL_ID = 'donghuna/timesformer-base-finetuned-k400-diving48'
    # Maximum number of predictions reported per video.
    TOP_K = 3

    def __init__(self, model_dir):
        """Load the classification model and switch it to evaluation mode.

        Args:
            model_dir: Accepted for interface compatibility with the caller
                that instantiates the handler. It is not used here: the
                weights are always loaded from ``MODEL_ID``.
        """
        self.model = TimesformerForVideoClassification.from_pretrained(
            self.MODEL_ID,
            ignore_mismatched_sizes=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the videos contained in a request payload.

        Args:
            data: Request payload. The value stored under ``"inputs"`` is
                passed unchanged to ``read_video``. The original docs describe
                it as base64-encoded video data -- TODO confirm against
                ``preprocessing.read_video``. A missing key yields ``None``,
                which is forwarded to ``read_video`` as-is.

        Returns:
            One dictionary per row of the model's logits, each with:
                ``"class_indices"``: class indices sorted by descending
                    probability.
                ``"probabilities"``: the matching softmax probabilities.
            Each list holds ``TOP_K`` entries, or fewer when the model
            exposes fewer than ``TOP_K`` classes.
        """
        inputs = data.get("inputs")
        videos = read_video(inputs)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model(videos)

        # Softmax over dim 1, i.e. logits are treated as (batch, num_classes).
        probabilities = torch.softmax(outputs.logits, dim=1)

        # torch.topk raises if k exceeds the size of the selected dimension,
        # so clamp k to the number of classes actually available.
        k = min(self.TOP_K, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k, dim=1)

        return [
            {"class_indices": indices, "probabilities": probs}
            for indices, probs in zip(top_indices.tolist(), top_probs.tolist())
        ]