🚸 [Add] try-except in loading cache files
Browse files- yolo/config/task/validation.yaml +1 -1
- yolo/tools/data_loader.py +9 -1
- yolo/tools/solver.py +2 -2
- yolo/utils/dataset_utils.py +8 -1
- yolo/utils/model_utils.py +2 -0
yolo/config/task/validation.yaml
CHANGED
@@ -7,7 +7,7 @@ data:
|
|
7 |
shuffle: False
|
8 |
pin_memory: True
|
9 |
data_augment: {}
|
10 |
-
dynamic_shape:
|
11 |
nms:
|
12 |
min_confidence: 0.0001
|
13 |
min_iou: 0.7
|
|
|
7 |
shuffle: False
|
8 |
pin_memory: True
|
9 |
data_augment: {}
|
10 |
+
dynamic_shape: False
|
11 |
nms:
|
12 |
min_confidence: 0.0001
|
13 |
min_iou: 0.7
|
yolo/tools/data_loader.py
CHANGED
@@ -56,7 +56,15 @@ class YoloDataset(Dataset):
|
|
56 |
data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
|
57 |
torch.save(data, cache_path)
|
58 |
else:
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
logger.info(f":package: Loaded {phase_name} cache")
|
61 |
return data
|
62 |
|
|
|
56 |
data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
|
57 |
torch.save(data, cache_path)
|
58 |
else:
|
59 |
+
try:
|
60 |
+
data = torch.load(cache_path, weights_only=False)
|
61 |
+
except Exception as e:
|
62 |
+
logger.error(
|
63 |
+
f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
|
64 |
+
":rotating_light: This may be caused by using cache from different other YOLO.\n"
|
65 |
+
":rotating_light: Please clean the cache and try running again."
|
66 |
+
)
|
67 |
+
raise e
|
68 |
logger.info(f":package: Loaded {phase_name} cache")
|
69 |
return data
|
70 |
|
yolo/tools/solver.py
CHANGED
@@ -56,7 +56,6 @@ class ValidateModel(BaseModel):
|
|
56 |
"map": batch_metrics["map"],
|
57 |
"map_50": batch_metrics["map_50"],
|
58 |
},
|
59 |
-
on_step=True,
|
60 |
batch_size=batch_size,
|
61 |
)
|
62 |
return predicts
|
@@ -102,9 +101,10 @@ class TrainModel(ValidateModel):
|
|
102 |
prog_bar=True,
|
103 |
on_epoch=True,
|
104 |
batch_size=batch_size,
|
|
|
105 |
rank_zero_only=True,
|
106 |
)
|
107 |
-
self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
|
108 |
return loss * batch_size
|
109 |
|
110 |
def configure_optimizers(self):
|
|
|
56 |
"map": batch_metrics["map"],
|
57 |
"map_50": batch_metrics["map_50"],
|
58 |
},
|
|
|
59 |
batch_size=batch_size,
|
60 |
)
|
61 |
return predicts
|
|
|
101 |
prog_bar=True,
|
102 |
on_epoch=True,
|
103 |
batch_size=batch_size,
|
104 |
+
sync_dist=True,
|
105 |
rank_zero_only=True,
|
106 |
)
|
107 |
+
self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, sync_dist=True, rank_zero_only=True)
|
108 |
return loss * batch_size
|
109 |
|
110 |
def configure_optimizers(self):
|
yolo/utils/dataset_utils.py
CHANGED
@@ -115,7 +115,14 @@ def scale_segmentation(
|
|
115 |
|
116 |
|
117 |
def tensorlize(data):
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
max_box = max(bbox.size(0) for bbox in bboxes)
|
120 |
padded_bbox_list = []
|
121 |
for bbox in bboxes:
|
|
|
115 |
|
116 |
|
117 |
def tensorlize(data):
|
118 |
+
try:
|
119 |
+
img_paths, bboxes, img_ratios = zip(*data)
|
120 |
+
except ValueError as e:
|
121 |
+
logger.error(
|
122 |
+
":rotating_light: This may be caused by using old cache or another version of YOLO's cache.\n"
|
123 |
+
":rotating_light: Please clean the cache and try running again."
|
124 |
+
)
|
125 |
+
raise e
|
126 |
max_box = max(bbox.size(0) for bbox in bboxes)
|
127 |
padded_bbox_list = []
|
128 |
for bbox in bboxes:
|
yolo/utils/model_utils.py
CHANGED
@@ -47,6 +47,8 @@ class EMA(Callback):
|
|
47 |
def setup(self, trainer, pl_module, stage):
|
48 |
pl_module.ema = deepcopy(pl_module.model)
|
49 |
self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
|
|
|
|
|
50 |
|
51 |
def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
|
52 |
for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
|
|
|
47 |
def setup(self, trainer, pl_module, stage):
|
48 |
pl_module.ema = deepcopy(pl_module.model)
|
49 |
self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
|
50 |
+
for param in pl_module.ema.parameters():
|
51 |
+
param.requires_grad = False
|
52 |
|
53 |
def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
|
54 |
for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
|