arsath-sm commited on
Commit
9f0b3a7
1 Parent(s): 6907110

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -12,8 +12,8 @@ def load_model():
12
 
13
  ort_session = load_model()
14
 
15
- # Define class names (update this based on your model's classes)
16
- CLASS_NAMES = ['car', 'license_plate']
17
 
18
  def preprocess_image(image, target_size=(640, 640)):
19
  if isinstance(image, Image.Image):
@@ -57,9 +57,9 @@ def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_thre
57
  for i in indices:
58
  box = boxes[i]
59
  score = scores[i]
60
- class_id = class_ids[i]
61
  x1, y1, x2, y2 = map(int, box)
62
- results.append((x1, y1, x2, y2, float(score), int(class_id)))
63
 
64
  return results
65
 
@@ -73,9 +73,9 @@ def process_image(image):
73
  results = postprocess_results(outputs, image.shape)
74
 
75
  for x1, y1, x2, y2, score, class_id in results:
76
- color = (0, 255, 0) if CLASS_NAMES[class_id] == 'car' else (255, 0, 0)
77
  cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
78
- label = f"{CLASS_NAMES[class_id]}: {score:.2f}"
79
  cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
80
 
81
  return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
 
12
 
13
  ort_session = load_model()
14
 
15
+ # Define class names and their corresponding indices
16
+ CLASS_NAMES = {0: 'car', 1: 'license_plate'}
17
 
18
  def preprocess_image(image, target_size=(640, 640)):
19
  if isinstance(image, Image.Image):
 
57
  for i in indices:
58
  box = boxes[i]
59
  score = scores[i]
60
+ class_id = int(class_ids[i])
61
  x1, y1, x2, y2 = map(int, box)
62
+ results.append((x1, y1, x2, y2, float(score), class_id))
63
 
64
  return results
65
 
 
73
  results = postprocess_results(outputs, image.shape)
74
 
75
  for x1, y1, x2, y2, score, class_id in results:
76
+ color = (0, 255, 0) if CLASS_NAMES.get(class_id, 'unknown') == 'car' else (255, 0, 0)
77
  cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
78
+ label = f"{CLASS_NAMES.get(class_id, 'unknown')}: {score:.2f}"
79
  cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
80
 
81
  return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)