Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,17 +5,14 @@ import onnxruntime as ort
|
|
5 |
from PIL import Image
|
6 |
import tempfile
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
#
|
15 |
-
|
16 |
-
0: (255, 0, 0), # Red for vehicles
|
17 |
-
1: (0, 255, 0) # Green for license plates
|
18 |
-
}
|
19 |
|
20 |
# Load the ONNX model
|
21 |
@st.cache_resource
|
@@ -44,6 +41,11 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
|
|
44 |
else:
|
45 |
raise ValueError(f"Unexpected output type: {type(output)}")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
if len(predictions.shape) == 4:
|
48 |
predictions = predictions.squeeze((0, 1))
|
49 |
elif len(predictions.shape) == 3:
|
@@ -54,6 +56,15 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
|
|
54 |
scores = predictions[:, 4]
|
55 |
class_ids = predictions[:, 5]
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# Filter by confidence
|
58 |
mask = scores > confidence_threshold
|
59 |
boxes = boxes[mask]
|
@@ -102,7 +113,8 @@ def process_image(image):
|
|
102 |
|
103 |
# Draw bounding boxes on the image
|
104 |
for x1, y1, x2, y2, score, class_id in results:
|
105 |
-
color
|
|
|
106 |
cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
|
107 |
|
108 |
label = f"{CLASSES[class_id]}: {score:.2f}"
|
@@ -146,7 +158,6 @@ def process_video(video_path):
|
|
146 |
(width, height)
|
147 |
)
|
148 |
|
149 |
-
# Add progress bar for video processing
|
150 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
151 |
progress_bar = st.progress(0)
|
152 |
frame_count = 0
|
@@ -159,7 +170,6 @@ def process_video(video_path):
|
|
159 |
processed_frame = process_image(frame)
|
160 |
out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
161 |
|
162 |
-
# Update progress bar
|
163 |
frame_count += 1
|
164 |
progress_bar.progress(frame_count / total_frames)
|
165 |
|
@@ -170,7 +180,7 @@ def process_video(video_path):
|
|
170 |
return temp_file.name
|
171 |
|
172 |
# Streamlit UI
|
173 |
-
st.title("
|
174 |
|
175 |
# Add confidence threshold slider
|
176 |
confidence_threshold = st.slider(
|
@@ -209,15 +219,16 @@ if uploaded_file is not None:
|
|
209 |
processed_video = process_video(tfile.name)
|
210 |
st.video(processed_video)
|
211 |
|
212 |
-
# Add legend
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
222 |
|
223 |
-
st.write("Upload an image or video to detect
|
|
|
5 |
from PIL import Image
|
6 |
import tempfile
|
7 |
|
8 |
+
# Dynamically assign colors to classes
|
9 |
+
def get_color(class_id):
|
10 |
+
"""Generate a color for any class ID"""
|
11 |
+
np.random.seed(class_id) # For consistent colors
|
12 |
+
return tuple(map(int, np.random.randint(0, 255, 3)))
|
13 |
+
|
14 |
+
# Class labels - will be populated dynamically
|
15 |
+
CLASSES = {}
|
|
|
|
|
|
|
16 |
|
17 |
# Load the ONNX model
|
18 |
@st.cache_resource
|
|
|
41 |
else:
|
42 |
raise ValueError(f"Unexpected output type: {type(output)}")
|
43 |
|
44 |
+
# Debug: Print the shape and first few entries of predictions
|
45 |
+
st.write(f"Debug - Predictions shape: {predictions.shape}")
|
46 |
+
if len(predictions) > 0:
|
47 |
+
st.write(f"Debug - First prediction entry: {predictions[0]}")
|
48 |
+
|
49 |
if len(predictions.shape) == 4:
|
50 |
predictions = predictions.squeeze((0, 1))
|
51 |
elif len(predictions.shape) == 3:
|
|
|
56 |
scores = predictions[:, 4]
|
57 |
class_ids = predictions[:, 5]
|
58 |
|
59 |
+
# Debug: Print unique class IDs
|
60 |
+
unique_classes = np.unique(class_ids)
|
61 |
+
st.write(f"Debug - Unique class IDs found: {unique_classes}")
|
62 |
+
|
63 |
+
# Update CLASSES dictionary with any new class IDs
|
64 |
+
for class_id in unique_classes:
|
65 |
+
if int(class_id) not in CLASSES:
|
66 |
+
CLASSES[int(class_id)] = f"Class_{int(class_id)}"
|
67 |
+
|
68 |
# Filter by confidence
|
69 |
mask = scores > confidence_threshold
|
70 |
boxes = boxes[mask]
|
|
|
113 |
|
114 |
# Draw bounding boxes on the image
|
115 |
for x1, y1, x2, y2, score, class_id in results:
|
116 |
+
# Get color dynamically
|
117 |
+
color = get_color(class_id)
|
118 |
cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
|
119 |
|
120 |
label = f"{CLASSES[class_id]}: {score:.2f}"
|
|
|
158 |
(width, height)
|
159 |
)
|
160 |
|
|
|
161 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
162 |
progress_bar = st.progress(0)
|
163 |
frame_count = 0
|
|
|
170 |
processed_frame = process_image(frame)
|
171 |
out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
|
172 |
|
|
|
173 |
frame_count += 1
|
174 |
progress_bar.progress(frame_count / total_frames)
|
175 |
|
|
|
180 |
return temp_file.name
|
181 |
|
182 |
# Streamlit UI
|
183 |
+
st.title("Object Detection")
|
184 |
|
185 |
# Add confidence threshold slider
|
186 |
confidence_threshold = st.slider(
|
|
|
219 |
processed_video = process_video(tfile.name)
|
220 |
st.video(processed_video)
|
221 |
|
222 |
+
# Add legend after processing to include all detected classes
|
223 |
+
if CLASSES:
|
224 |
+
st.markdown("### Detection Legend")
|
225 |
+
for class_id, class_name in CLASSES.items():
|
226 |
+
color = get_color(class_id)
|
227 |
+
st.markdown(
|
228 |
+
f'<div style="display: flex; align-items: center;">'
|
229 |
+
f'<div style="width: 20px; height: 20px; background-color: rgb{color}; margin-right: 10px;"></div>'
|
230 |
+
f'<span>{class_name}</span></div>',
|
231 |
+
unsafe_allow_html=True
|
232 |
+
)
|
233 |
|
234 |
+
st.write("Upload an image or video to detect objects.")
|