donghuna committed on
Commit 5b4786e · verified · 1 Parent(s): cb27a5d

Create handler.py

Files changed (1): handler.py +78 -0
handler.py ADDED
@@ -0,0 +1,78 @@
+ import io
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ from transformers import TimesformerForVideoClassification
+ from ftplib import FTP
+ import av
+
+
+ class EndpointHandler:
+     def __init__(self, ftp_host, ftp_user, ftp_password, model_dir=""):
+         self.model = TimesformerForVideoClassification.from_pretrained(model_dir)
+         # Replace the classification head with a 48-class linear layer.
+         # Note: this re-initializes the head weights, so it is only needed if the
+         # checkpoint in model_dir does not already have a 48-class classifier.
+         self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)
+         self.model.eval()
+
+         # FTP connection details - update these as required
+         self.ftp_host = ftp_host
+         self.ftp_user = ftp_user
+         self.ftp_password = ftp_password
+
+         # Frame size and clip length expected by the model
+         self.target_size = (224, 224)
+         self.num_frames = 24
+
+     def __call__(self, data):
+         video_path = data.get("video_path")
+         start_frame = data.get("start_frame", 0)
+         end_frame = data.get("end_frame", 48)  # default frame window, can be adjusted
+
+         # Connect to FTP and read the video
+         with FTP(self.ftp_host) as ftp:
+             ftp.login(self.ftp_user, self.ftp_password)
+             video_tensor = self.read_and_process_video(
+                 ftp, video_path, start_frame, end_frame, self.target_size, self.num_frames
+             )
+
+         # Perform inference
+         with torch.no_grad():
+             outputs = self.model(video_tensor.unsqueeze(0))  # add batch dimension
+             predictions = torch.softmax(outputs.logits, dim=-1)
+             predicted_class = torch.argmax(predictions, dim=-1).item()
+
+         return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
+
+     def read_video_from_ftp(self, ftp, file_path, start_frame, end_frame):
+         # Download the file into memory, then decode it with PyAV
+         video_data = io.BytesIO()
+         ftp.retrbinary(f'RETR {file_path}', video_data.write)
+         video_data.seek(0)
+         container = av.open(video_data, format='mp4')
+         frames = [frame.to_ndarray(format="rgb24").astype(np.uint8) for frame in container.decode(video=0)]
+         # Keep only the requested frame window
+         return np.stack(frames[start_frame:end_frame], axis=0)
+
+     def sample_frames(self, frames, num_frames):
+         # Zero-pad clips that are too short; uniformly subsample clips that are too long
+         total_frames = len(frames)
+         sampled_frames = list(frames)
+         if total_frames < num_frames:
+             padding = [np.zeros_like(frames[0]) for _ in range(num_frames - total_frames)]
+             sampled_frames.extend(padding)
+         elif total_frames > num_frames:
+             indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
+             sampled_frames = [frames[i] for i in indices]
+         return np.array(sampled_frames)
+
+     def pad_and_resize(self, frames, target_size):
+         # Convert each frame to a PIL image, resize, and return a (T, C, H, W) float tensor in [0, 1]
+         transform = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(target_size),
+             transforms.ToTensor()
+         ])
+         processed_frames = [transform(frame) for frame in frames]
+         return torch.stack(processed_frames)
+
+     def read_and_process_video(self, ftp, file_path, start_frame, end_frame, target_size, num_frames):
+         frames = self.read_video_from_ftp(ftp, file_path, start_frame, end_frame)
+         frames = self.sample_frames(frames, num_frames=num_frames)
+         processed_frames = self.pad_and_resize(frames, target_size=target_size)
+         # Timesformer expects pixel_values of shape (batch, num_frames, channels, height, width),
+         # so keep the (T, C, H, W) layout here; the batch dimension is added in __call__.
+         return processed_frames
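
For reference, a minimal smoke test that drives the handler directly might look like the sketch below. The host, credentials, checkpoint name, and video path are all placeholders, not values from this repo. Note also that the stock Inference Endpoints runtime instantiates handlers as EndpointHandler(path=...), so a constructor taking FTP credentials assumes the class is created manually.

# Hypothetical smoke test for EndpointHandler (all values are placeholders)
handler = EndpointHandler(
    ftp_host="ftp.example.com",                            # assumed FTP server
    ftp_user="user",
    ftp_password="secret",
    model_dir="facebook/timesformer-base-finetuned-k400",  # assumed base checkpoint
)
result = handler({
    "video_path": "/videos/sample.mp4",  # assumed path on the FTP server
    "start_frame": 0,
    "end_frame": 48,
})
print(result["predicted_class"], max(result["predictions"][0]))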