Spaces:
Runtime error
Runtime error
Update tasks/image.py
Browse files- tasks/image.py +3 -6
tasks/image.py
CHANGED
@@ -163,7 +163,7 @@ def compute_max_iou(true_boxes, pred_box):
|
|
163 |
|
164 |
@router.post(ROUTE, tags=["Image Task"],
|
165 |
description=DESCRIPTION)
|
166 |
-
async def evaluate_image(model_path: str = "
|
167 |
"""
|
168 |
Evaluate image classification and object detection for forest fire smoke.
|
169 |
|
@@ -184,10 +184,7 @@ async def evaluate_image(model_path: str = "models/yolo11s_best.pt", request: Im
|
|
184 |
# Split dataset
|
185 |
train_test = dataset["train"]
|
186 |
test_dataset = dataset["val"]
|
187 |
-
|
188 |
-
model = YOLO(model_path, task="detect")
|
189 |
-
if("detr" in model_path):
|
190 |
-
model = RTDETR(model_path)
|
191 |
|
192 |
# Start tracking emissions
|
193 |
tracker.start()
|
@@ -203,7 +200,7 @@ async def evaluate_image(model_path: str = "models/yolo11s_best.pt", request: Im
|
|
203 |
pred_boxes = []
|
204 |
true_boxes_list = [] # List of lists, each inner list contains boxes for one image
|
205 |
|
206 |
-
for example in
|
207 |
# Parse true annotation (YOLO format: class_id x_center y_center width height)
|
208 |
annotation = example.get("annotations", "").strip()
|
209 |
has_smoke = len(annotation) > 0
|
|
|
163 |
|
164 |
@router.post(ROUTE, tags=["Image Task"],
|
165 |
description=DESCRIPTION)
|
166 |
+
async def evaluate_image(model_path: str = "models_v2/rt_detr_fp16.engine", request: ImageEvaluationRequest = ImageEvaluationRequest()):
|
167 |
"""
|
168 |
Evaluate image classification and object detection for forest fire smoke.
|
169 |
|
|
|
184 |
# Split dataset
|
185 |
train_test = dataset["train"]
|
186 |
test_dataset = dataset["val"]
|
187 |
+
model = RTDETR(model_path)
|
|
|
|
|
|
|
188 |
|
189 |
# Start tracking emissions
|
190 |
tracker.start()
|
|
|
200 |
pred_boxes = []
|
201 |
true_boxes_list = [] # List of lists, each inner list contains boxes for one image
|
202 |
|
203 |
+
for example in test_dataset:
|
204 |
# Parse true annotation (YOLO format: class_id x_center y_center width height)
|
205 |
annotation = example.get("annotations", "").strip()
|
206 |
has_smoke = len(annotation) > 0
|