arsath-sm committed on
Commit
de391b1
·
verified ·
1 Parent(s): 6af8008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -26
app.py CHANGED
@@ -5,17 +5,14 @@ import onnxruntime as ort
5
  from PIL import Image
6
  import tempfile
7
 
8
- # Class labels for both vehicles and license plates
9
- CLASSES = {
10
- 0: "Vehicle",
11
- 1: "License_Plate"
12
- }
13
-
14
- # Different colors for each class
15
- COLORS = {
16
- 0: (255, 0, 0), # Red for vehicles
17
- 1: (0, 255, 0) # Green for license plates
18
- }
19
 
20
  # Load the ONNX model
21
  @st.cache_resource
@@ -44,6 +41,11 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
44
  else:
45
  raise ValueError(f"Unexpected output type: {type(output)}")
46
 
 
 
 
 
 
47
  if len(predictions.shape) == 4:
48
  predictions = predictions.squeeze((0, 1))
49
  elif len(predictions.shape) == 3:
@@ -54,6 +56,15 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
54
  scores = predictions[:, 4]
55
  class_ids = predictions[:, 5]
56
 
 
 
 
 
 
 
 
 
 
57
  # Filter by confidence
58
  mask = scores > confidence_threshold
59
  boxes = boxes[mask]
@@ -102,7 +113,8 @@ def process_image(image):
102
 
103
  # Draw bounding boxes on the image
104
  for x1, y1, x2, y2, score, class_id in results:
105
- color = COLORS[class_id]
 
106
  cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
107
 
108
  label = f"{CLASSES[class_id]}: {score:.2f}"
@@ -146,7 +158,6 @@ def process_video(video_path):
146
  (width, height)
147
  )
148
 
149
- # Add progress bar for video processing
150
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
151
  progress_bar = st.progress(0)
152
  frame_count = 0
@@ -159,7 +170,6 @@ def process_video(video_path):
159
  processed_frame = process_image(frame)
160
  out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
161
 
162
- # Update progress bar
163
  frame_count += 1
164
  progress_bar.progress(frame_count / total_frames)
165
 
@@ -170,7 +180,7 @@ def process_video(video_path):
170
  return temp_file.name
171
 
172
  # Streamlit UI
173
- st.title("Vehicle and License Plate Detection")
174
 
175
  # Add confidence threshold slider
176
  confidence_threshold = st.slider(
@@ -209,15 +219,16 @@ if uploaded_file is not None:
209
  processed_video = process_video(tfile.name)
210
  st.video(processed_video)
211
 
212
- # Add legend
213
- st.markdown("### Detection Legend")
214
- for class_id, class_name in CLASSES.items():
215
- color = COLORS[class_id]
216
- st.markdown(
217
- f'<div style="display: flex; align-items: center;">'
218
- f'<div style="width: 20px; height: 20px; background-color: rgb{color}; margin-right: 10px;"></div>'
219
- f'<span>{class_name}</span></div>',
220
- unsafe_allow_html=True
221
- )
 
222
 
223
- st.write("Upload an image or video to detect vehicles and license plates.")
 
5
  from PIL import Image
6
  import tempfile
7
 
8
# Dynamically assign colors to classes
def get_color(class_id):
    """Return a deterministic RGB color tuple for a class ID.

    Args:
        class_id: Integer (or float-valued) class identifier; coerced to
            int so IDs coming straight out of the detection array work.

    Returns:
        A tuple of three Python ints in [0, 255] — the same tuple every
        time for the same class_id.
    """
    # Use a locally-seeded generator instead of np.random.seed(): seeding
    # the global RNG here would silently clobber global random state for
    # the rest of the app on every color lookup.
    rng = np.random.default_rng(int(class_id))
    # Upper bound 256 (exclusive) so the full 0-255 channel range is
    # reachable; the original randint(0, 255, 3) could never produce 255.
    return tuple(int(channel) for channel in rng.integers(0, 256, size=3))
13
+
14
+ # Class labels - will be populated dynamically
15
+ CLASSES = {}
 
 
 
16
 
17
  # Load the ONNX model
18
  @st.cache_resource
 
41
  else:
42
  raise ValueError(f"Unexpected output type: {type(output)}")
43
 
44
+ # Debug: Print the shape and first few entries of predictions
45
+ st.write(f"Debug - Predictions shape: {predictions.shape}")
46
+ if len(predictions) > 0:
47
+ st.write(f"Debug - First prediction entry: {predictions[0]}")
48
+
49
  if len(predictions.shape) == 4:
50
  predictions = predictions.squeeze((0, 1))
51
  elif len(predictions.shape) == 3:
 
56
  scores = predictions[:, 4]
57
  class_ids = predictions[:, 5]
58
 
59
+ # Debug: Print unique class IDs
60
+ unique_classes = np.unique(class_ids)
61
+ st.write(f"Debug - Unique class IDs found: {unique_classes}")
62
+
63
+ # Update CLASSES dictionary with any new class IDs
64
+ for class_id in unique_classes:
65
+ if int(class_id) not in CLASSES:
66
+ CLASSES[int(class_id)] = f"Class_{int(class_id)}"
67
+
68
  # Filter by confidence
69
  mask = scores > confidence_threshold
70
  boxes = boxes[mask]
 
113
 
114
  # Draw bounding boxes on the image
115
  for x1, y1, x2, y2, score, class_id in results:
116
+ # Get color dynamically
117
+ color = get_color(class_id)
118
  cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
119
 
120
  label = f"{CLASSES[class_id]}: {score:.2f}"
 
158
  (width, height)
159
  )
160
 
 
161
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
162
  progress_bar = st.progress(0)
163
  frame_count = 0
 
170
  processed_frame = process_image(frame)
171
  out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
172
 
 
173
  frame_count += 1
174
  progress_bar.progress(frame_count / total_frames)
175
 
 
180
  return temp_file.name
181
 
182
  # Streamlit UI
183
+ st.title("Object Detection")
184
 
185
  # Add confidence threshold slider
186
  confidence_threshold = st.slider(
 
219
  processed_video = process_video(tfile.name)
220
  st.video(processed_video)
221
 
222
+ # Add legend after processing to include all detected classes
223
+ if CLASSES:
224
+ st.markdown("### Detection Legend")
225
+ for class_id, class_name in CLASSES.items():
226
+ color = get_color(class_id)
227
+ st.markdown(
228
+ f'<div style="display: flex; align-items: center;">'
229
+ f'<div style="width: 20px; height: 20px; background-color: rgb{color}; margin-right: 10px;"></div>'
230
+ f'<span>{class_name}</span></div>',
231
+ unsafe_allow_html=True
232
+ )
233
 
234
+ st.write("Upload an image or video to detect objects.")