Souha Ben Hassine committed
Commit c32f512 · 1 Parent(s): e017cd8

initial commit

Files changed (1)
1. app.py  +28 -22
app.py CHANGED
@@ -126,30 +126,36 @@ def load_frames_from_video(video_path, transform, num_frames=10):
     frames = frames.unsqueeze(0)  # Add batch dimension
     return frames
 
-# Prediction function for a single video
 def predict_video(model, video_path, text_input, tokenizer, transform):
-    # Set model to evaluation mode
-    model.eval()
-
-    # Tokenize the text input and move to device
-    encoding = tokenizer(
-        text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
-    )
-    encoding = {key: val.to(device) for key, val in encoding.items()}  # Ensure text input is on the device
-
-    # Load frames from the video and move to device
-    frames = load_frames_from_video(video_path, transform)
-    frames = frames.to(device)  # Ensure frames are on the device
-
-    # Perform forward pass through the model
-    with torch.no_grad():
-        output = model(encoding, frames)
-
-    # Apply sigmoid to get probability, then threshold to get prediction
-    prediction = (output.squeeze(-1) > 0.5).float()
-
-    # Return the predicted label (0 or 1)
-    return prediction.item()
+    try:
+        # Set model to evaluation mode
+        model.eval()
+
+        # Tokenize the text input
+        encoding = tokenizer(
+            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
+        )
+        encoding = {key: val.to(device) for key, val in encoding.items()}
+
+        # Load frames from the video
+        frames = load_frames_from_video(video_path, transform)
+        frames = frames.to(device)
+
+        # Log input shapes and devices
+        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")
+
+        # Perform forward pass through the model
+        with torch.no_grad():
+            output = model(encoding, frames)
+
+        # Apply sigmoid to get probability, then threshold to get prediction
+        prediction = (output.squeeze(-1) > 0.5).float()
+
+        return prediction.item()
+
+    except Exception as e:
+        print(f"Prediction error: {e}")
+        return "Error during prediction"
 
 
 
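For context, a minimal sketch of how the updated predict_video might be called. Everything here apart from predict_video, load_frames_from_video, and device is an assumption for illustration: the tokenizer checkpoint, the transform, and the load_model helper are placeholders, not part of this commit.

# Usage sketch (not part of this commit); names marked below are hypothetical.
import torch
from torchvision import transforms
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed text tokenizer
transform = transforms.Compose([                                # assumed frame preprocessing
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

model = load_model().to(device)  # hypothetical helper returning the trained multimodal model

result = predict_video(model, "sample.mp4", "a person riding a bike", tokenizer, transform)
if isinstance(result, str):      # the except branch returns the string "Error during prediction"
    print(result)
else:
    print(f"Predicted label: {int(result)}")  # 0 or 1

Because the except branch returns a string rather than a float, any caller should check the return type before casting, as above.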