dschandra commited on
Commit
0bc4e56
·
verified ·
1 Parent(s): e0af381

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import YolosImageProcessor, YolosForObjectDetection
4
  import cv2
5
  import numpy as np
6
  import tempfile
@@ -19,11 +19,11 @@ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power genera
19
 
20
  # UI controls for optimization parameters
21
  st.sidebar.header("Analysis Settings")
22
- frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=30, value=15)
23
- batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=12)
24
  resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
25
- resize_width = 640 if resize_enabled else None
26
- quantize_model = st.sidebar.checkbox("Quantize Model (faster on CPU)", value=False)
27
 
28
  # Load model and processor
29
  @st.cache_resource
@@ -32,12 +32,12 @@ def load_model(quantize=quantize_model):
32
  logging.set_verbosity_error()
33
 
34
  try:
35
- processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
36
- model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  model.to(device)
39
 
40
- # Apply dynamic quantization for CPU if enabled
41
  if quantize and device.type == "cpu":
42
  model = torch.quantization.quantize_dynamic(
43
  model, {torch.nn.Linear}, dtype=torch.qint8
@@ -46,7 +46,7 @@ def load_model(quantize=quantize_model):
46
  model.eval()
47
  return processor, model, device
48
  except Exception as e:
49
- st.error(f"Failed to load model: {str(e)}. Please check your internet connection or clear the cache (~/.cache/huggingface/hub).")
50
  raise
51
 
52
  processor, model, device = load_model()
@@ -57,7 +57,7 @@ def resize_frame(frame, width=None):
57
  return frame
58
  aspect_ratio = frame.shape[1] / frame.shape[0]
59
  height = int(width / aspect_ratio)
60
- return cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
61
 
62
  # Function to process a batch of frames
63
  async def detect_faults_batch(frames, processor, model, device):
@@ -162,7 +162,7 @@ async def process_video(video_path, frame_skip, batch_size):
162
  frames_batch = []
163
  processed_frames += batch_size
164
  progress.progress(min(processed_frames / total_frames, 1.0))
165
-
166
  frame_count += 1
167
 
168
  if frames_batch:
@@ -200,8 +200,13 @@ if uploaded_file is not None:
200
 
201
  st.video(tfile.name, format="video/mp4")
202
 
203
- loop = asyncio.get_event_loop()
204
- output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip, batch_size))
 
 
 
 
 
205
 
206
  if output_path and video_faults:
207
  st.subheader("Fault Detection Results")
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
  import cv2
5
  import numpy as np
6
  import tempfile
 
19
 
20
  # UI controls for optimization parameters
21
  st.sidebar.header("Analysis Settings")
22
+ frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=50, value=30)
23
+ batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=32, value=16 if torch.cuda.is_available() else 8)
24
  resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
25
+ resize_width = 512 if resize_enabled else None
26
+ quantize_model = st.sidebar.checkbox("Quantize Model (faster, esp. on CPU)", value=True)
27
 
28
  # Load model and processor
29
  @st.cache_resource
 
32
  logging.set_verbosity_error()
33
 
34
  try:
35
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
36
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  model.to(device)
39
 
40
+ # Apply dynamic quantization if enabled
41
  if quantize and device.type == "cpu":
42
  model = torch.quantization.quantize_dynamic(
43
  model, {torch.nn.Linear}, dtype=torch.qint8
 
46
  model.eval()
47
  return processor, model, device
48
  except Exception as e:
49
+ st.error(f"Failed to load model: {str(e)}. Check internet connection or cache (~/.cache/huggingface/hub).")
50
  raise
51
 
52
  processor, model, device = load_model()
 
57
  return frame
58
  aspect_ratio = frame.shape[1] / frame.shape[0]
59
  height = int(width / aspect_ratio)
60
+ return cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
61
 
62
  # Function to process a batch of frames
63
  async def detect_faults_batch(frames, processor, model, device):
 
162
  frames_batch = []
163
  processed_frames += batch_size
164
  progress.progress(min(processed_frames / total_frames, 1.0))
165
+ GROUP BY
166
  frame_count += 1
167
 
168
  if frames_batch:
 
200
 
201
  st.video(tfile.name, format="video/mp4")
202
 
203
+ # Create a new event loop for Streamlit's ScriptRunner thread
204
+ loop = asyncio.new_event_loop()
205
+ asyncio.set_event_loop(loop)
206
+ try:
207
+ output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip, batch_size))
208
+ finally:
209
+ loop.close()
210
 
211
  if output_path and video_faults:
212
  st.subheader("Fault Detection Results")