Update app.py

app.py CHANGED

@@ -1,6 +1,6 @@
 import streamlit as st
 import torch
-from transformers import …
+from transformers import YolosImageProcessor, YolosForObjectDetection
 import cv2
 import numpy as np
 import tempfile
@@ -19,23 +19,35 @@ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power genera
 
 # UI controls for optimization parameters
 st.sidebar.header("Analysis Settings")
-frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=…
-batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=…
+frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=30, value=15)
+batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=12)
 resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
 resize_width = 640 if resize_enabled else None
+quantize_model = st.sidebar.checkbox("Quantize Model (faster on CPU)", value=False)
 
 # Load model and processor
 @st.cache_resource
-def load_model():
+def load_model(quantize=quantize_model):
     warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
     logging.set_verbosity_error()
 
-    … [old lines 33–38: original load_model body, truncated in the diff view]
+    try:
+        processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
+        model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model.to(device)
+
+        # Apply dynamic quantization for CPU if enabled
+        if quantize and device.type == "cpu":
+            model = torch.quantization.quantize_dynamic(
+                model, {torch.nn.Linear}, dtype=torch.qint8
+            )
+
+        model.eval()
+        return processor, model, device
+    except Exception as e:
+        st.error(f"Failed to load model: {str(e)}. Please check your internet connection or clear the cache (~/.cache/huggingface/hub).")
+        raise
 
 processor, model, device = load_model()
 
@@ -50,7 +62,6 @@ def resize_frame(frame, width=None):
 # Function to process a batch of frames
 async def detect_faults_batch(frames, processor, model, device):
     try:
-        # Resize frames if enabled
         frames = [resize_frame(frame, resize_width) for frame in frames]
         inputs = processor(images=frames, return_tensors="pt").to(device)
         with torch.no_grad():
@@ -87,7 +98,6 @@ async def detect_faults_batch(frames, processor, model, device):
            annotated_frames.append(annotated_frame)
            all_faults.append(faults)
 
-        # Clear GPU memory
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
@@ -109,7 +119,6 @@ async def process_video(video_path, frame_skip, batch_size):
     fps = int(cap.get(cv2.CAP_PROP_FPS))
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
-    # Adjust output size if resizing
     out_width = resize_width if resize_width else frame_width
     out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
 
@@ -132,7 +141,6 @@ async def process_video(video_path, frame_skip, batch_size):
            break
 
        if frame_count % frame_skip != 0:
-            # Resize frame for output if needed
            frame = resize_frame(frame, resize_width)
            out.write(frame)
            frame_count += 1