Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
from transformers import
|
4 |
import cv2
|
5 |
import numpy as np
|
6 |
import tempfile
|
@@ -19,11 +19,11 @@ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power genera
|
|
19 |
|
20 |
# UI controls for optimization parameters
|
21 |
st.sidebar.header("Analysis Settings")
|
22 |
-
frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=
|
23 |
-
batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=
|
24 |
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
|
25 |
-
resize_width =
|
26 |
-
quantize_model = st.sidebar.checkbox("Quantize Model (faster on CPU)", value=
|
27 |
|
28 |
# Load model and processor
|
29 |
@st.cache_resource
|
@@ -32,12 +32,12 @@ def load_model(quantize=quantize_model):
|
|
32 |
logging.set_verbosity_error()
|
33 |
|
34 |
try:
|
35 |
-
processor =
|
36 |
-
model =
|
37 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
model.to(device)
|
39 |
|
40 |
-
# Apply dynamic quantization
|
41 |
if quantize and device.type == "cpu":
|
42 |
model = torch.quantization.quantize_dynamic(
|
43 |
model, {torch.nn.Linear}, dtype=torch.qint8
|
@@ -46,7 +46,7 @@ def load_model(quantize=quantize_model):
|
|
46 |
model.eval()
|
47 |
return processor, model, device
|
48 |
except Exception as e:
|
49 |
-
st.error(f"Failed to load model: {str(e)}.
|
50 |
raise
|
51 |
|
52 |
processor, model, device = load_model()
|
@@ -57,7 +57,7 @@ def resize_frame(frame, width=None):
|
|
57 |
return frame
|
58 |
aspect_ratio = frame.shape[1] / frame.shape[0]
|
59 |
height = int(width / aspect_ratio)
|
60 |
-
return cv2.resize(frame, (width, height), interpolation=cv2.
|
61 |
|
62 |
# Function to process a batch of frames
|
63 |
async def detect_faults_batch(frames, processor, model, device):
|
@@ -162,7 +162,7 @@ async def process_video(video_path, frame_skip, batch_size):
|
|
162 |
frames_batch = []
|
163 |
processed_frames += batch_size
|
164 |
progress.progress(min(processed_frames / total_frames, 1.0))
|
165 |
-
|
166 |
frame_count += 1
|
167 |
|
168 |
if frames_batch:
|
@@ -200,8 +200,13 @@ if uploaded_file is not None:
|
|
200 |
|
201 |
st.video(tfile.name, format="video/mp4")
|
202 |
|
203 |
-
loop
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
if output_path and video_faults:
|
207 |
st.subheader("Fault Detection Results")
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
4 |
import cv2
|
5 |
import numpy as np
|
6 |
import tempfile
|
|
|
19 |
|
20 |
# UI controls for optimization parameters
|
21 |
st.sidebar.header("Analysis Settings")
|
22 |
+
frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=50, value=30)
|
23 |
+
batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=32, value=16 if torch.cuda.is_available() else 8)
|
24 |
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
|
25 |
+
resize_width = 512 if resize_enabled else None
|
26 |
+
quantize_model = st.sidebar.checkbox("Quantize Model (faster, esp. on CPU)", value=True)
|
27 |
|
28 |
# Load model and processor
|
29 |
@st.cache_resource
|
|
|
32 |
logging.set_verbosity_error()
|
33 |
|
34 |
try:
|
35 |
+
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
|
36 |
+
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
37 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
model.to(device)
|
39 |
|
40 |
+
# Apply dynamic quantization if enabled
|
41 |
if quantize and device.type == "cpu":
|
42 |
model = torch.quantization.quantize_dynamic(
|
43 |
model, {torch.nn.Linear}, dtype=torch.qint8
|
|
|
46 |
model.eval()
|
47 |
return processor, model, device
|
48 |
except Exception as e:
|
49 |
+
st.error(f"Failed to load model: {str(e)}. Check internet connection or cache (~/.cache/huggingface/hub).")
|
50 |
raise
|
51 |
|
52 |
processor, model, device = load_model()
|
|
|
57 |
return frame
|
58 |
aspect_ratio = frame.shape[1] / frame.shape[0]
|
59 |
height = int(width / aspect_ratio)
|
60 |
+
return cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
|
61 |
|
62 |
# Function to process a batch of frames
|
63 |
async def detect_faults_batch(frames, processor, model, device):
|
|
|
162 |
frames_batch = []
|
163 |
processed_frames += batch_size
|
164 |
progress.progress(min(processed_frames / total_frames, 1.0))
|
165 |
+
GROUP BY
|
166 |
frame_count += 1
|
167 |
|
168 |
if frames_batch:
|
|
|
200 |
|
201 |
st.video(tfile.name, format="video/mp4")
|
202 |
|
203 |
+
# Create a new event loop for Streamlit's ScriptRunner thread
|
204 |
+
loop = asyncio.new_event_loop()
|
205 |
+
asyncio.set_event_loop(loop)
|
206 |
+
try:
|
207 |
+
output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip, batch_size))
|
208 |
+
finally:
|
209 |
+
loop.close()
|
210 |
|
211 |
if output_path and video_faults:
|
212 |
st.subheader("Fault Detection Results")
|