|
import numpy as np |
|
import torch |
|
from transformers import TimesformerForVideoClassification |
|
from preprocessing import read_video |
|
import logging |
|
import json |
|
import traceback |
|
|
|
|
|
# Configure root logging once at import time; use a module-level logger
# (standard `__name__` convention) for all handler output.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
class EndpointHandler:
    """Inference endpoint handler for 48-class diving-video classification.

    Loads a TimeSformer checkpoint and replaces its classification head with
    a fresh 48-way linear layer (Diving48 label set).
    """

    def __init__(self, model_dir):
        # NOTE(review): `model_dir` (supplied by the hosting runtime) is
        # ignored and the checkpoint is pinned to a hub repo instead —
        # confirm this is intentional.
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        # Swap in a 48-way head sized to the backbone's feature width.
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)
        self.model.eval()

    def __call__(self, data):
        """Classify a base64-encoded video.

        Expects ``data == {"inputs": {"video_base64": "<b64 string>"}}``.

        Returns:
            dict with ``predicted_class`` (int) and ``predictions``
            (softmax distribution as a nested list) on success, or a
            ``(json_string, 500)`` tuple on failure.
        """
        try:
            inputs = data.get("inputs", {})
            video_base64 = inputs.get("video_base64", "")
            # Decode the base64 payload into model-ready frames.
            # Assumes read_video returns an array-like of shape
            # (frames, channels, H, W) — TODO confirm against preprocessing.
            processed_frames = read_video(video_base64)

            # BUG FIX: the original returned {"processed_frames": ...} here,
            # making everything below unreachable — the model never ran. The
            # dead code also re-read raw data['frames'] instead of using the
            # frames just decoded. (Also dropped: logging the full request,
            # which dumped the entire base64 video into the logs.)
            frames = torch.tensor(np.array(processed_frames)).float()
            logger.info("Frames shape: %s", frames.shape)

            with torch.no_grad():
                # unsqueeze(0) adds the batch dimension the model expects.
                outputs = self.model(frames.unsqueeze(0))
                predictions = torch.softmax(outputs.logits, dim=-1)

            predicted_class = torch.argmax(predictions, dim=-1).item()
            logger.info("Predicted class: %s", predicted_class)

            return {"predicted_class": predicted_class, "predictions": predictions.tolist()}

        except Exception as e:
            # logger.exception records the stack trace automatically; the
            # (json_string, 500) return shape is kept for caller compatibility.
            logger.exception("Inference failed: %s", e)
            return json.dumps({
                "status": "error",
                "message": str(e),
                "stack_trace": traceback.format_exc()
            }), 500