File size: 2,352 Bytes
5b4786e
 
 
8b2bb9b
54947ac
5d6e7ad
d232895
54947ac
 
 
 
5b4786e
 
a5cd004
1f30fb0
 
 
 
5b4786e
 
 
 
1ced666
7af0913
f345b61
 
1ced666
a2e5f1c
1ced666
 
 
 
 
 
 
 
 
fca2c62
1ced666
 
 
 
 
 
 
 
 
 
 
 
 
 
fca2c62
1ced666
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback

# Logging setup: obtain this module's named logger and route INFO-and-above
# records through a default root handler.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)

class EndpointHandler:
    """Inference handler that classifies a base64-encoded video clip.

    The TimeSformer video-classification model is loaded once in the
    constructor. Each call decodes the request's video with ``read_video``,
    runs the model, and returns softmax probabilities and the argmax class.
    """

    # Number of outputs the classification head must produce.
    NUM_CLASSES = 48

    def __init__(self, model_dir):
        """Load the model, make sure the head has ``NUM_CLASSES`` outputs, set eval mode.

        Args:
            model_dir: Directory supplied by the serving runtime.
                NOTE(review): currently unused -- weights are always fetched
                from the hub repo below; confirm whether ``model_dir`` should
                be used instead.
        """
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        classifier = self.model.classifier
        # Replace the head only when its size is wrong. Replacing it
        # unconditionally discards the fine-tuned weights and leaves a
        # randomly initialised layer, which makes predictions meaningless.
        if classifier.out_features != self.NUM_CLASSES:
            logger.warning(
                "Classifier has %d outputs, expected %d; replacing it with an "
                "untrained layer -- predictions will not be meaningful.",
                classifier.out_features, self.NUM_CLASSES,
            )
            self.model.classifier = torch.nn.Linear(classifier.in_features, self.NUM_CLASSES)
        self.model.eval()

    @staticmethod
    def _to_model_input(frames):
        """Convert decoded frames to a float32 tensor with a batch axis.

        TimeSformer expects ``pixel_values`` shaped
        (batch, num_frames, channels, height, width). A 4-D clip gets a
        leading batch axis of size 1; other ranks are passed through unchanged.
        NOTE(review): assumes ``read_video`` returns (num_frames, channels,
        height, width) -- confirm against the preprocessing module.
        """
        tensor = torch.as_tensor(np.asarray(frames), dtype=torch.float32)
        if tensor.dim() == 4:
            tensor = tensor.unsqueeze(0)
        return tensor

    def __call__(self, data):
        """Classify the video in ``data["inputs"]["video_base64"]``.

        Args:
            data: Request payload whose ``"inputs"`` dict holds a
                ``"video_base64"`` string.

        Returns:
            On success, a dict with ``"predicted_class"`` (int argmax index;
            a single clip is expected) and ``"predictions"`` (softmax
            probabilities as nested lists). On any failure, a
            ``(json_string, 500)`` tuple describing the error.
        """
        try:
            inputs = data.get("inputs", {})
            video_base64 = inputs.get("video_base64", "")
            if not video_base64:
                raise ValueError("Request is missing 'inputs.video_base64'")

            # Log only the payload structure: the raw base64 video can be
            # very large and would flood the logs.
            logger.info("Received request with input keys: %s", list(inputs))

            processed_frames = read_video(video_base64)
            frames = self._to_model_input(processed_frames)
            logger.info("Frames shape: %s", tuple(frames.shape))

            # Inference only: no autograd bookkeeping needed.
            with torch.no_grad():
                outputs = self.model(frames)
                predictions = torch.softmax(outputs.logits, dim=-1)
            logger.info("Predictions: %s", predictions)

            predicted_class = torch.argmax(predictions, dim=-1).item()
            logger.info("Predicted class: %s", predicted_class)

            return {"predicted_class": predicted_class, "predictions": predictions.tolist()}

        except Exception as e:
            error_message = str(e)
            stack_trace = traceback.format_exc()
            logger.exception("Error: %s", error_message)
            # NOTE(review): this returns the full stack trace to the client,
            # which leaks internals. The (body, status) tuple is a Flask
            # convention -- confirm the serving runtime honours it rather than
            # serialising the tuple with HTTP 200. Kept as-is for
            # compatibility with existing callers.
            return json.dumps({
                "status": "error",
                "message": error_message,
                "stack_trace": stack_trace
            }), 500