dschandra commited on
Commit
e0af381
·
verified ·
1 Parent(s): 54dcee8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import DetrImageProcessor, DetrForObjectDetection
4
  import cv2
5
  import numpy as np
6
  import tempfile
@@ -19,23 +19,35 @@ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power genera
19
 
20
  # UI controls for optimization parameters
21
  st.sidebar.header("Analysis Settings")
22
- frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=20, value=10)
23
- batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=8)
24
  resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
25
  resize_width = 640 if resize_enabled else None
 
26
 
27
  # Load model and processor
28
  @st.cache_resource
29
- def load_model():
30
  warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
31
  logging.set_verbosity_error()
32
 
33
- processor = DetrImageProcessor.from_pretrained("hustvl/yolos-tiny")
34
- model = DetrForObjectDetection.from_pretrained("hustvl/yolos-tiny")
35
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- model.to(device)
37
- model.eval()
38
- return processor, model, device
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  processor, model, device = load_model()
41
 
@@ -50,7 +62,6 @@ def resize_frame(frame, width=None):
50
  # Function to process a batch of frames
51
  async def detect_faults_batch(frames, processor, model, device):
52
  try:
53
- # Resize frames if enabled
54
  frames = [resize_frame(frame, resize_width) for frame in frames]
55
  inputs = processor(images=frames, return_tensors="pt").to(device)
56
  with torch.no_grad():
@@ -87,7 +98,6 @@ async def detect_faults_batch(frames, processor, model, device):
87
  annotated_frames.append(annotated_frame)
88
  all_faults.append(faults)
89
 
90
- # Clear GPU memory
91
  if torch.cuda.is_available():
92
  torch.cuda.empty_cache()
93
 
@@ -109,7 +119,6 @@ async def process_video(video_path, frame_skip, batch_size):
109
  fps = int(cap.get(cv2.CAP_PROP_FPS))
110
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
111
 
112
- # Adjust output size if resizing
113
  out_width = resize_width if resize_width else frame_width
114
  out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
115
 
@@ -132,7 +141,6 @@ async def process_video(video_path, frame_skip, batch_size):
132
  break
133
 
134
  if frame_count % frame_skip != 0:
135
- # Resize frame for output if needed
136
  frame = resize_frame(frame, resize_width)
137
  out.write(frame)
138
  frame_count += 1
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import YolosImageProcessor, YolosForObjectDetection
4
  import cv2
5
  import numpy as np
6
  import tempfile
 
19
 
20
  # UI controls for optimization parameters
21
  st.sidebar.header("Analysis Settings")
22
+ frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=30, value=15)
23
+ batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=12)
24
  resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
25
  resize_width = 640 if resize_enabled else None
26
+ quantize_model = st.sidebar.checkbox("Quantize Model (faster on CPU)", value=False)
27
 
28
  # Load model and processor
29
  @st.cache_resource
30
+ def load_model(quantize=quantize_model):
31
  warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
32
  logging.set_verbosity_error()
33
 
34
+ try:
35
+ processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
36
+ model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ model.to(device)
39
+
40
+ # Apply dynamic quantization for CPU if enabled
41
+ if quantize and device.type == "cpu":
42
+ model = torch.quantization.quantize_dynamic(
43
+ model, {torch.nn.Linear}, dtype=torch.qint8
44
+ )
45
+
46
+ model.eval()
47
+ return processor, model, device
48
+ except Exception as e:
49
+ st.error(f"Failed to load model: {str(e)}. Please check your internet connection or clear the cache (~/.cache/huggingface/hub).")
50
+ raise
51
 
52
  processor, model, device = load_model()
53
 
 
62
  # Function to process a batch of frames
63
  async def detect_faults_batch(frames, processor, model, device):
64
  try:
 
65
  frames = [resize_frame(frame, resize_width) for frame in frames]
66
  inputs = processor(images=frames, return_tensors="pt").to(device)
67
  with torch.no_grad():
 
98
  annotated_frames.append(annotated_frame)
99
  all_faults.append(faults)
100
 
 
101
  if torch.cuda.is_available():
102
  torch.cuda.empty_cache()
103
 
 
119
  fps = int(cap.get(cv2.CAP_PROP_FPS))
120
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
121
 
 
122
  out_width = resize_width if resize_width else frame_width
123
  out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
124
 
 
141
  break
142
 
143
  if frame_count % frame_skip != 0:
 
144
  frame = resize_frame(frame, resize_width)
145
  out.write(frame)
146
  frame_count += 1