import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_dir):
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)  # 48 output classes (Diving48)
        self.model.eval()

    def __call__(self, data):
        try:
            # Debugging: inspect the incoming request (keys only, to avoid logging the base64 payload)
            logger.info(f"Received data keys: {list(data.keys())}")

            inputs = data.get("inputs", {})
            video_base64 = inputs.get("video_base64", "")

            # Decode and preprocess the video into a frame array
            processed_frames = read_video(video_base64)
            frames = torch.tensor(np.array(processed_frames)).float()  # Ensure the data is in the correct format

            # Debugging: inspect the frame tensor
            logger.info(f"Frames shape: {frames.shape}")

            # Perform inference
            with torch.no_grad():
                outputs = self.model(frames.unsqueeze(0))  # Add batch dimension
                predictions = torch.softmax(outputs.logits, dim=-1)

            # Debugging: inspect the prediction results
            logger.info(f"Predictions: {predictions}")

            predicted_class = torch.argmax(predictions, dim=-1).item()

            # Debugging: inspect the predicted class
            logger.info(f"Predicted class: {predicted_class}")

            return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
        except Exception as e:
            error_message = str(e)
            stack_trace = traceback.format_exc()
            logger.error(f"Error: {error_message}")
            logger.error(f"Stack trace: {stack_trace}")
            return json.dumps({
                "status": "error",
                "message": error_message,
                "stack_trace": stack_trace
            }), 500
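

# Minimal local smoke test, included as a sketch only. It assumes a sample clip at the
# hypothetical path "sample_dive.mp4" and that preprocessing.read_video accepts a
# base64-encoded video string, as it is used in __call__ above. The payload shape
# ({"inputs": {"video_base64": ...}}) mirrors what the handler expects.
if __name__ == "__main__":
    import base64

    with open("sample_dive.mp4", "rb") as f:  # hypothetical sample clip
        video_base64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(model_dir=".")  # model_dir is unused; weights load from the Hub
    result = handler({"inputs": {"video_base64": video_base64}})
    logger.info(f"Result: {result}")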