Sanjayraju30 commited on
Commit
b12d8ba
·
verified ·
1 Parent(s): 4c8c110

Update services/services/fault_service.py

Browse files
Files changed (1) hide show
  1. services/services/fault_service.py +29 -7
services/services/fault_service.py CHANGED
@@ -1,11 +1,33 @@
1
  from ultralytics import YOLO
 
 
2
 
3
- fault_model = YOLO("pole_fault_model.pt")
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def detect_pole_faults(image_path):
6
- results = fault_model(image_path)
7
- flagged = []
8
- for r in results:
9
- if hasattr(r, 'fault_type') and r.conf > 0.6:
10
- flagged.append({"fault_type": r.fault_type, "confidence": r.conf})
11
- return flagged
 
 
 
 
 
 
 
 
 
1
  from ultralytics import YOLO
2
+ import os
3
+ import torch.serialization
4
 
5
+ # Allowlist Ultralytics DetectionModel to avoid UnpicklingError
6
+ torch.serialization.add_safe_globals(['ultralytics.nn.tasks.DetectionModel'])
7
+
8
+ def load_fault_model():
9
+ model_path = "pole_fault_model.pt"
10
+ if os.path.exists(model_path):
11
+ print(f"Loading custom model: {model_path}")
12
+ return YOLO(model_path)
13
+ else:
14
+ print(f"Warning: {model_path} not found. Falling back to YOLOv8s.")
15
+ return YOLO("yolov8s.pt") # Fallback to pre-trained YOLOv8s
16
+
17
+ fault_model = load_fault_model()
18
 
19
  def detect_pole_faults(image_path):
20
+ try:
21
+ results = fault_model(image_path)
22
+ flagged = []
23
+ for r in results:
24
+ # Check if model is custom-trained with fault_type (for custom models)
25
+ if hasattr(r, 'fault_type') and r.conf > 0.6:
26
+ flagged.append({"fault_type": r.fault_type, "confidence": r.conf})
27
+ # Fallback for generic YOLOv8 models (no fault_type)
28
+ elif r.names and r.conf > 0.6:
29
+ flagged.append({"fault_type": r.names[int(r.cls)], "confidence": r.conf})
30
+ return flagged
31
+ except Exception as e:
32
+ print(f"Error in fault detection: {e}")
33
+ return []