import numpy as np
import torch
from transformers import TimesformerForVideoClassification
from preprocessing import read_video
import logging
import json
import traceback
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, model_dir):
        # model_dir is unused here; the checkpoint is loaded directly from the Hub.
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        # Replace the classification head with a 48-class linear layer (Diving48).
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)  # 48 output classes
        self.model.eval()
    def __call__(self, data):
        try:
            # Debugging: check the input data
            logger.info(f"Received data: {data}")

            inputs = data.get("inputs", {})
            video_base64 = inputs.get("video_base64", "")
            processed_frames = read_video(video_base64)

            # Debugging: check the frame data
            frames = torch.tensor(np.array(processed_frames)).float()  # ensure a float tensor
            logger.info(f"Frames shape: {frames.shape}")

            # Perform inference
            with torch.no_grad():
                outputs = self.model(frames.unsqueeze(0))  # add batch dimension
                predictions = torch.softmax(outputs.logits, dim=-1)

            # Debugging: check the prediction results
            logger.info(f"Predictions: {predictions}")

            predicted_class = torch.argmax(predictions, dim=-1).item()
            # Debugging: check the predicted class
            logger.info(f"Predicted class: {predicted_class}")

            return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
        except Exception as e:
            error_message = str(e)
            stack_trace = traceback.format_exc()
            logger.error(f"Error: {error_message}")
            logger.error(f"Stack trace: {stack_trace}")
            # Return a plain dict so the error payload is serialized like the success path.
            return {
                "status": "error",
                "message": error_message,
                "stack_trace": stack_trace
            }
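
# Minimal local smoke test: a sketch of how the handler can be invoked,
# assuming read_video accepts a base64-encoded video string as above.
# "sample.mp4" is a placeholder path, not a file shipped with this repo.
if __name__ == "__main__":
    import base64

    with open("sample.mp4", "rb") as f:
        video_base64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(model_dir=".")
    result = handler({"inputs": {"video_base64": video_base64}})
    print(result)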