File size: 1,772 Bytes
5b4786e 8b2bb9b 54947ac 5b4786e a5cd004 1f30fb0 5b4786e 8b2bb9b e4249e7 8b2bb9b fca2c62 54947ac fca2c62 a5cd004 fca2c62 54947ac 5b4786e a5cd004 5b4786e fca2c62 54947ac fca2c62 5b4786e fca2c62 54947ac 5b4786e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
# Module-wide logger for this handler; the root logging config below emits
# records at INFO level and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class EndpointHandler:
    """Inference handler that classifies a video clip with a TimeSformer model.

    The instance is callable: ``handler(data)`` returns the predicted class
    index together with the softmax probabilities for every class.
    """

    # Hub checkpoint loaded in __init__.
    MODEL_ID = "donghuna/timesformer-base-finetuned-k400-diving48"
    # Number of output classes of the classification head.
    NUM_CLASSES = 48

    def __init__(self, model_dir):
        """Load the classifier and put it in evaluation mode.

        Args:
            model_dir: Directory argument supplied by the serving runtime.
                NOTE(review): currently unused; the weights are always
                fetched from ``MODEL_ID``. Confirm whether ``model_dir``
                should be loaded instead.
        """
        # Passing num_labels together with ignore_mismatched_sizes keeps the
        # checkpoint's trained classification head when it already has
        # NUM_CLASSES outputs, and re-initialises it only when sizes differ.
        # Do not assign a new torch.nn.Linear to `classifier` afterwards: a
        # freshly built layer has random weights and would throw away the
        # fine-tuned head, making every prediction meaningless.
        self.model = TimesformerForVideoClassification.from_pretrained(
            self.MODEL_ID,
            num_labels=self.NUM_CLASSES,
            ignore_mismatched_sizes=True,
        )
        self.model.eval()

    def __call__(self, data):
        """Classify a single clip.

        Args:
            data: Request payload. If it has a top-level ``"frames"`` entry,
                those frames are used as-is. Otherwise the clip is loaded with
                ``read_video(data["inputs"]["video_path"],
                data["inputs"].get("ftp_password"))``.

        Returns:
            dict with ``"predicted_class"`` (int index of the most probable
            class) and ``"predictions"`` (nested list of softmax
            probabilities with one row for the single clip).

        Raises:
            KeyError: if neither ``"frames"`` nor ``inputs.video_path`` is
                present in ``data``.
            ValueError: if the frames are not 4-D, or 5-D with batch size 1.
        """
        frames = self._load_frames(data)
        pixel_values = self._to_batched_tensor(frames)
        logger.info("Frames shape: %s", tuple(pixel_values.shape))

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = self.model(pixel_values).logits

        result = self._postprocess(logits)
        logger.debug("Predictions: %s", result["predictions"])
        logger.info("Predicted class: %s", result["predicted_class"])
        return result

    @staticmethod
    def _load_frames(data):
        """Return raw frames from the payload, reading the video when needed."""
        if "frames" in data:
            return data["frames"]
        inputs = data["inputs"]
        video_path = inputs["video_path"]
        # Log only the path: the payload may carry an FTP password, which must
        # never be written to the logs.
        logger.info("Reading video: %s", video_path)
        return read_video(video_path, inputs.get("ftp_password"))

    @staticmethod
    def _to_batched_tensor(frames):
        """Convert frames to a float32 tensor with a leading batch dim of 1.

        A 4-D input is treated as a single clip and gets a batch dimension
        added; a 5-D input must already have batch size 1.
        NOTE(review): the axis order the model expects is not visible here;
        presumably it matches what ``read_video`` returns. Confirm.
        """
        tensor = torch.as_tensor(np.asarray(frames), dtype=torch.float32)
        if tensor.ndim == 4:
            return tensor.unsqueeze(0)
        if tensor.ndim == 5 and tensor.shape[0] == 1:
            return tensor
        raise ValueError(
            "expected 4-D frames for one clip or 5-D frames with batch size 1, "
            f"got shape {tuple(tensor.shape)}"
        )

    @staticmethod
    def _postprocess(logits):
        """Turn single-clip logits into the response dictionary."""
        probabilities = torch.softmax(logits, dim=-1)
        predicted_class = int(torch.argmax(probabilities, dim=-1).item())
        return {
            "predicted_class": predicted_class,
            "predictions": probabilities.tolist(),
        }
|