DSatishchandra commited on
Commit
b3f84ea
·
verified ·
1 Parent(s): ca750b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -2,6 +2,9 @@ import gradio as gr
2
  import torch
3
  from model import CNNLSTMClassifier
4
  from utils import extract_frames
 
 
 
5
 
6
  model = CNNLSTMClassifier()
7
  model.load_state_dict(torch.load("lbw_classifier.pt", map_location='cpu'))
@@ -13,14 +16,39 @@ def predict(video_file):
13
  if isinstance(video_file, dict) and "name" in video_file:
14
  video_path = video_file["name"]
15
  else:
16
- video_path = video_file # fallback for older Gradio behavior
17
 
 
18
  frames = extract_frames(video_path)
19
  with torch.no_grad():
20
  output = model(frames)
21
  pred = torch.argmax(output, dim=1).item()
22
  prob = torch.softmax(output, dim=1)[0][pred].item()
23
- return f"Prediction: {classes[pred]} (Confidence: {prob:.2%})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  iface = gr.Interface(
26
  fn=predict,
 
2
  import torch
3
  from model import CNNLSTMClassifier
4
  from utils import extract_frames
5
+ import shutil
6
+ import os
7
+ import cv2
8
 
9
  model = CNNLSTMClassifier()
10
  model.load_state_dict(torch.load("lbw_classifier.pt", map_location='cpu'))
 
16
  if isinstance(video_file, dict) and "name" in video_file:
17
  video_path = video_file["name"]
18
  else:
19
+ video_path = video_file
20
 
21
+ # Predict
22
  frames = extract_frames(video_path)
23
  with torch.no_grad():
24
  output = model(frames)
25
  pred = torch.argmax(output, dim=1).item()
26
  prob = torch.softmax(output, dim=1)[0][pred].item()
27
+
28
+ label = f"{classes[pred]} ({prob:.2%})"
29
+
30
+ # Create annotated video
31
+ cap = cv2.VideoCapture(video_path)
32
+ out_path = "/tmp/annotated_video.mp4"
33
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
34
+ fps = cap.get(cv2.CAP_PROP_FPS)
35
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
36
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
37
+ out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
38
+
39
+ font = cv2.FONT_HERSHEY_SIMPLEX
40
+ color = (0, 255, 0) if pred == 1 else (0, 0, 255)
41
+
42
+ while True:
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ break
46
+ cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
47
+ out.write(frame)
48
+ cap.release()
49
+ out.release()
50
+
51
+ return out_path
52
 
53
  iface = gr.Interface(
54
  fn=predict,