donghuna's picture
Update handler.py
e4249e7 verified
raw
history blame
1.77 kB
import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
# Logging setup: configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the handler below.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference handler that classifies a video clip with a TimeSformer model.

    An instance is callable: it receives a request payload (a dict), obtains
    the clip frames, runs the model and returns the predicted class index
    together with the softmax probabilities.
    """

    # Hub checkpoint loaded at start-up.
    MODEL_ID = 'donghuna/timesformer-base-finetuned-k400-diving48'
    # Number of output classes the classification head must expose.
    NUM_CLASSES = 48
    # Placeholder written to the logs instead of the real FTP password.
    _REDACTED = "***"

    def __init__(self, model_dir):
        """Load the model and switch it to evaluation mode.

        Args:
            model_dir: Accepted for interface compatibility; it is not used,
                the weights are always loaded from ``MODEL_ID``.
        """
        self.model = TimesformerForVideoClassification.from_pretrained(
            self.MODEL_ID,
            ignore_mismatched_sizes=True,
        )
        # Only swap the head when its size is wrong. Replacing it
        # unconditionally would throw away trained classifier weights and
        # leave a randomly initialised layer in their place.
        # NOTE(review): presumably the checkpoint already ships a 48-class
        # head - confirm against the checkpoint's config.
        head = self.model.classifier
        if getattr(head, "out_features", None) != self.NUM_CLASSES:
            self.model.classifier = torch.nn.Linear(head.in_features, self.NUM_CLASSES)
        self.model.eval()

    @staticmethod
    def _redact(data):
        """Return ``data`` with ``inputs.ftp_password`` masked, for logging.

        The original payload is never mutated; payloads without a password
        are returned unchanged.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if not isinstance(inputs, dict) or "ftp_password" not in inputs:
            return data
        return {**data, "inputs": {**inputs, "ftp_password": EndpointHandler._REDACTED}}

    @staticmethod
    def _as_batch(frames):
        """Convert ``frames`` to a float32 tensor with a leading batch axis.

        A 4-D input gets a batch dimension prepended; any other rank is
        passed through untouched.
        NOTE(review): assumes a 4-D clip is a single unbatched video and that
        the model wants 5-D input - confirm against the TimeSformer docs.
        """
        if isinstance(frames, torch.Tensor):
            tensor = frames.float()
        else:
            tensor = torch.tensor(np.array(frames)).float()
        if tensor.dim() == 4:
            tensor = tensor.unsqueeze(0)  # Add batch dimension
        return tensor

    def __call__(self, data):
        """Classify the clip described by ``data``.

        Args:
            data: Request payload. Either ``data["frames"]`` holds the frames
                directly, or ``data["inputs"]["video_path"]`` (plus optional
                ``data["inputs"]["ftp_password"]``) identifies a video that
                is decoded with ``read_video``.

        Returns:
            dict with ``"predicted_class"`` (int index of the most likely
            class) and ``"predictions"`` (nested list of softmax
            probabilities).

        Raises:
            KeyError: if neither ``"frames"`` nor ``"inputs"``/``"video_path"``
                is present in ``data``.
        """
        # Debug: inspect the incoming payload (the password is masked so the
        # credential never reaches the logs).
        logger.info("Received data: %s", self._redact(data))

        if "frames" in data:
            # Frames supplied inline take precedence, as before.
            frames = data["frames"]
        else:
            # Previously the decoded video was computed and then ignored,
            # so requests without "frames" failed with KeyError.
            # NOTE(review): assumes read_video returns an array-like or
            # tensor of frames - confirm against preprocessing.read_video.
            inputs = data["inputs"]
            frames = read_video(inputs["video_path"], inputs.get("ftp_password"))

        batch = self._as_batch(frames)
        # Debug: inspect the tensor handed to the model.
        logger.info("Frames shape: %s", batch.shape)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            outputs = self.model(batch)
            predictions = torch.softmax(outputs.logits, dim=-1)
        # Debug: inspect the class probabilities.
        logger.info("Predictions: %s", predictions)

        predicted_class = torch.argmax(predictions, dim=-1).item()
        # Debug: inspect the chosen class.
        logger.info("Predicted class: %s", predicted_class)
        return {"predicted_class": predicted_class, "predictions": predictions.tolist()}