tom-b974 commited on
Commit
61985ff
·
verified ·
1 Parent(s): 7dfd446

implementation of inference

Browse files
Files changed (1) hide show
  1. tasks/image.py +21 -21
tasks/image.py CHANGED
@@ -6,6 +6,7 @@ from sklearn.metrics import accuracy_score
6
  import random
7
  import os
8
 
 
9
  from .utils.evaluation import ImageEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
@@ -14,9 +15,11 @@ load_dotenv()
14
 
15
  router = APIRouter()
16
 
17
- DESCRIPTION = "Random Baseline"
18
  ROUTE = "/image"
19
 
 
 
20
  def parse_boxes(annotation_string):
21
  """Parse multiple boxes from a single annotation string.
22
  Each box has 5 values: class_id, x_center, y_center, width, height"""
@@ -99,37 +102,34 @@ async def evaluate_image(request: ImageEvaluationRequest):
99
  # YOUR MODEL INFERENCE CODE HERE
100
  # Update the code below to replace the random baseline with your model inference
101
  #--------------------------------------------------------------------------------------------
102
-
103
  predictions = []
104
  true_labels = []
105
  pred_boxes = []
106
- true_boxes_list = [] # List of lists, each inner list contains boxes for one image
107
-
108
  for example in test_dataset:
109
- # Parse true annotation (YOLO format: class_id x_center y_center width height)
110
  annotation = example.get("annotations", "").strip()
111
  has_smoke = len(annotation) > 0
112
  true_labels.append(int(has_smoke))
113
 
114
- # Make random classification prediction
115
- pred_has_smoke = random.random() > 0.5
116
- predictions.append(int(pred_has_smoke))
117
-
118
- # If there's a true box, parse it and make random box prediction
119
  if has_smoke:
120
- # Parse all true boxes from the annotation
121
  image_true_boxes = parse_boxes(annotation)
122
  true_boxes_list.append(image_true_boxes)
123
-
124
- # For baseline, make one random box prediction per image
125
- # In a real model, you might want to predict multiple boxes
126
- random_box = [
127
- random.random(), # x_center
128
- random.random(), # y_center
129
- random.random() * 0.5, # width (max 0.5)
130
- random.random() * 0.5 # height (max 0.5)
131
- ]
132
- pred_boxes.append(random_box)
 
 
 
 
133
 
134
  #--------------------------------------------------------------------------------------------
135
  # YOUR MODEL INFERENCE STOPS HERE
 
6
  import random
7
  import os
8
 
9
+ from ultralytics import YOLO # Import YOLO
10
  from .utils.evaluation import ImageEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
12
 
 
15
 
16
  router = APIRouter()
17
 
18
+ DESCRIPTION = "YOLO Smoke Detection"
19
  ROUTE = "/image"
20
 
21
+ yolo_model = YOLO("models/best.pt")
22
+
23
  def parse_boxes(annotation_string):
24
  """Parse multiple boxes from a single annotation string.
25
  Each box has 5 values: class_id, x_center, y_center, width, height"""
 
102
  # YOUR MODEL INFERENCE CODE HERE
103
  # Update the code below to replace the random baseline with your model inference
104
  #--------------------------------------------------------------------------------------------
 
105
  predictions = []
106
  true_labels = []
107
  pred_boxes = []
108
+ true_boxes_list = []
109
+
110
  for example in test_dataset:
111
+ image = example["image"]
112
  annotation = example.get("annotations", "").strip()
113
  has_smoke = len(annotation) > 0
114
  true_labels.append(int(has_smoke))
115
 
 
 
 
 
 
116
  if has_smoke:
 
117
  image_true_boxes = parse_boxes(annotation)
118
  true_boxes_list.append(image_true_boxes)
119
+
120
+ # Perform YOLO inference
121
+ results = yolo_model.predict(image)
122
+ if len(results[0].boxes): # If predictions exist
123
+ pred_box = results[0].boxes.xywh[0].cpu().numpy().tolist() # First box in YOLO format
124
+ pred_boxes.append(pred_box)
125
+ else:
126
+ pred_boxes.append([]) # No prediction for this image
127
+ else:
128
+ true_boxes_list.append([])
129
+ pred_boxes.append([])
130
+
131
+ # Classification: If predictions exist, assume smoke is present
132
+ predictions.append(1 if len(results[0].boxes) > 0 else 0)
133
 
134
  #--------------------------------------------------------------------------------------------
135
  # YOUR MODEL INFERENCE STOPS HERE