|
import numpy as np |
|
import torch |
|
from transformers import TimesformerForVideoClassification |
|
from preprocessing import read_video |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
class EndpointHandler:
    """Inference handler for a Timesformer video-classification endpoint.

    Loads a Timesformer checkpoint fine-tuned on Diving48, swaps in a
    48-way classification head, and classifies a video referenced in the
    request payload:

        {"inputs": {"video_path": ..., "ftp_password": ...}}
    """

    # Diving48 label-set size; used to rebuild the classifier head.
    NUM_CLASSES = 48

    def __init__(self, model_dir):
        # NOTE(review): `model_dir` is ignored — the checkpoint is pulled
        # from the Hub by a hard-coded id. Kept for interface compatibility.
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True,
        )
        # Replace the head so logits match the Diving48 class count.
        self.model.classifier = torch.nn.Linear(
            self.model.classifier.in_features, self.NUM_CLASSES
        )
        self.model.eval()

    def __call__(self, data):
        """Classify the video named in ``data`` and return softmax scores.

        Args:
            data: dict with an ``"inputs"`` mapping holding ``"video_path"``
                and optionally ``"ftp_password"``.

        Returns:
            dict with ``"predicted_class"`` (int index of the argmax class)
            and ``"predictions"`` (nested list of softmax probabilities).
        """
        inputs = data["inputs"]
        video_path = inputs["video_path"]
        ftp_password = inputs.get("ftp_password")

        # Lazy %-style args: avoids formatting (and dumping the whole
        # payload) unless INFO is enabled.
        logger.info("Received request for video: %s", video_path)

        # BUG FIX: the original decoded the video via read_video() but then
        # discarded the result and read data['frames'] instead — a key the
        # payload is not otherwise shown to carry. Use the decoded frames.
        processed_frames = read_video(video_path, ftp_password)
        # assumes read_video returns a (num_frames, ...) array-like suitable
        # for the model after adding a batch dim — TODO confirm layout
        frames = torch.as_tensor(np.asarray(processed_frames)).float()

        logger.info("Frames shape: %s", tuple(frames.shape))

        with torch.no_grad():
            # unsqueeze(0) adds the batch dimension expected by the model.
            outputs = self.model(frames.unsqueeze(0))
            predictions = torch.softmax(outputs.logits, dim=-1)

        logger.info("Predictions: %s", predictions)

        predicted_class = torch.argmax(predictions, dim=-1).item()
        logger.info("Predicted class: %s", predicted_class)

        return {
            "predicted_class": predicted_class,
            "predictions": predictions.tolist(),
        }
|
|