import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, model_dir):
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)  # 48 output classes
        self.model.eval()
    def __call__(self, data):
        try:
            # Debugging: inspect the incoming payload
            logger.info(f"Received data: {data}")

            inputs = data.get("inputs", {})
            video_base64 = inputs.get("video_base64", "")

            # Decode the base64-encoded video and sample frames from it
            processed_frames = read_video(video_base64)

            frames = torch.tensor(np.array(processed_frames)).float()  # Ensure the data is in the correct format

            # Debugging: inspect the frame tensor shape
            logger.info(f"Frames shape: {frames.shape}")

            # Perform inference
            with torch.no_grad():
                outputs = self.model(frames.unsqueeze(0))  # Add batch dimension
                predictions = torch.softmax(outputs.logits, dim=-1)

            # Debugging: inspect the raw prediction scores
            logger.info(f"Predictions: {predictions}")

            predicted_class = torch.argmax(predictions, dim=-1).item()

            # Debugging: inspect the predicted class index
            logger.info(f"Predicted class: {predicted_class}")

            return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
        except Exception as e:
            error_message = str(e)
            stack_trace = traceback.format_exc()
            logger.error(f"Error: {error_message}")
            logger.error(f"Stack trace: {stack_trace}")
            return json.dumps({
                "status": "error",
                "message": error_message,
                "stack_trace": stack_trace
            }), 500
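

# --- Local smoke test: an illustrative sketch, not part of the deployed handler. ---
# Assumes a sample clip exists at "sample.mp4" (hypothetical path) and that
# preprocessing.read_video returns frames in the layout the model expects,
# e.g. (num_frames, channels, height, width); adjust for your environment.
if __name__ == "__main__":
    import base64

    with open("sample.mp4", "rb") as f:
        video_base64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(model_dir=".")
    result = handler({"inputs": {"video_base64": video_base64}})
    print(result)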