Create handler.py
handler.py (ADDED) +78 -0
@@ -0,0 +1,78 @@
import io

import numpy as np
import torch
from torchvision import transforms
from transformers import TimesformerForVideoClassification
from ftplib import FTP
import av


class EndpointHandler:
    def __init__(self, ftp_host, ftp_user, ftp_password, model_dir=""):
        self.model = TimesformerForVideoClassification.from_pretrained(model_dir)
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)  # 48 output classes
        self.model.eval()

        # FTP connection details - update as required
        self.ftp_host = ftp_host
        self.ftp_user = ftp_user
        self.ftp_password = ftp_password

        # Target frame size and number of frames fed to the model
        self.target_size = (224, 224)
        self.num_frames = 24

    def __call__(self, data):
        video_path = data.get("video_path")
        start_frame = data.get("start_frame", 0)
        end_frame = data.get("end_frame", 48)  # default frame range, adjust as needed

        # Connect to FTP and read the video
        with FTP(self.ftp_host) as ftp:
            ftp.login(self.ftp_user, self.ftp_password)
            video_tensor = self.read_and_process_video(ftp, video_path, start_frame, end_frame, self.target_size, self.num_frames)

        # Perform inference
        with torch.no_grad():
            outputs = self.model(video_tensor.unsqueeze(0))  # add batch dimension
            predictions = torch.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(predictions, dim=-1).item()

        return {"predicted_class": predicted_class, "predictions": predictions.tolist()}

    def read_video_from_ftp(self, ftp, file_path, start_frame, end_frame):
        # Download the file into memory, decode it, and keep only the requested frame range
        video_data = io.BytesIO()
        ftp.retrbinary(f'RETR {file_path}', video_data.write)
        video_data.seek(0)
        container = av.open(video_data, format='mp4')
        frames = [frame.to_ndarray(format="rgb24").astype(np.uint8) for frame in container.decode(video=0)]
        return np.stack(frames[start_frame:end_frame], axis=0)

    def sample_frames(self, frames, num_frames):
        total_frames = len(frames)
        sampled_frames = list(frames)
        if total_frames < num_frames:
            # Pad short clips with black frames
            padding = [np.zeros_like(frames[0]) for _ in range(num_frames - total_frames)]
            sampled_frames.extend(padding)
        elif total_frames > num_frames:
            # Uniformly sample num_frames frames across the clip
            indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
            sampled_frames = [frames[i] for i in indices]
        return np.array(sampled_frames)

    def pad_and_resize(self, frames, target_size):
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(target_size),
            transforms.ToTensor()
        ])
        processed_frames = [transform(frame) for frame in frames]
        return torch.stack(processed_frames)  # (T, C, H, W)

    def read_and_process_video(self, ftp, file_path, start_frame, end_frame, target_size, num_frames):
        frames = self.read_video_from_ftp(ftp, file_path, start_frame, end_frame)
        frames = self.sample_frames(frames, num_frames=num_frames)
        processed_frames = self.pad_and_resize(frames, target_size=target_size)
        # TimesformerForVideoClassification expects pixel_values of shape
        # (batch_size, num_frames, num_channels, height, width), so the tensor
        # stays (T, C, H, W) here; __call__ adds the batch dimension.
        return processed_frames
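A minimal usage sketch, assuming a reachable FTP server and a TimeSformer checkpoint; the host, credentials, and video path below are placeholders, and facebook/timesformer-base-finetuned-k400 is only an example base checkpoint (the classifier head is re-initialized to 48 classes regardless):

handler = EndpointHandler(
    ftp_host="ftp.example.com",   # placeholder host
    ftp_user="user",              # placeholder credentials
    ftp_password="password",
    model_dir="facebook/timesformer-base-finetuned-k400",  # example base checkpoint
)

result = handler({"video_path": "videos/clip.mp4", "start_frame": 0, "end_frame": 48})
print(result["predicted_class"])

Note that because the classifier head is freshly initialized, the predictions are only meaningful once weights fine-tuned for the 48-class task have been loaded.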