π [Fix] PIL read image, manually close image
Browse files- yolo/tools/data_loader.py +2 -1
- yolo/tools/loss_functions.py +3 -3
- yolo/utils/bounding_box_utils.py +2 -1
- yolo/utils/logger.py +1 -1
- yolo/utils/logging_utils.py +0 -3
yolo/tools/data_loader.py
CHANGED
@@ -133,7 +133,8 @@ class YoloDataset(Dataset):
|
|
133 |
|
134 |
def get_data(self, idx):
|
135 |
img_path, bboxes = self.data[idx]
|
136 |
-
|
|
|
137 |
return img, bboxes, img_path
|
138 |
|
139 |
def get_more_data(self, num: int = 1):
|
|
|
133 |
|
134 |
def get_data(self, idx):
|
135 |
img_path, bboxes = self.data[idx]
|
136 |
+
with Image.open(img_path) as img:
|
137 |
+
img = img.convert("RGB")
|
138 |
return img, bboxes, img_path
|
139 |
|
140 |
def get_more_data(self, num: int = 1):
|
yolo/tools/loss_functions.py
CHANGED
@@ -129,9 +129,9 @@ class DualLoss:
|
|
129 |
self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
|
130 |
self.cls_rate * (aux_cls * self.aux_rate + main_cls),
|
131 |
]
|
132 |
-
loss_dict =
|
133 |
-
|
134 |
-
|
135 |
return sum(total_loss), loss_dict
|
136 |
|
137 |
|
|
|
129 |
self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
|
130 |
self.cls_rate * (aux_cls * self.aux_rate + main_cls),
|
131 |
]
|
132 |
+
loss_dict = {
|
133 |
+
f"Loss/{name}Loss": value.detach().item() for name, value in zip(["Box", "DFL", "BCE"], total_loss)
|
134 |
+
}
|
135 |
return sum(total_loss), loss_dict
|
136 |
|
137 |
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -69,7 +69,8 @@ def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
69 |
(bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
|
70 |
)
|
71 |
v = (4 / (math.pi**2)) * (arctan**2)
|
72 |
-
|
|
|
73 |
# Compute CIoU
|
74 |
ciou = diou - alpha * v
|
75 |
return ciou.to(dtype)
|
|
|
69 |
(bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
|
70 |
)
|
71 |
v = (4 / (math.pi**2)) * (arctan**2)
|
72 |
+
with torch.no_grad():
|
73 |
+
alpha = v / (v - iou + 1 + EPS)
|
74 |
# Compute CIoU
|
75 |
ciou = diou - alpha * v
|
76 |
return ciou.to(dtype)
|
yolo/utils/logger.py
CHANGED
@@ -3,7 +3,7 @@ import logging
|
|
3 |
from rich.console import Console
|
4 |
from rich.logging import RichHandler
|
5 |
|
6 |
-
logger = logging.getLogger(
|
7 |
logger.setLevel(logging.DEBUG)
|
8 |
logger.propagate = False
|
9 |
if not logger.hasHandlers():
|
|
|
3 |
from rich.console import Console
|
4 |
from rich.logging import RichHandler
|
5 |
|
6 |
+
logger = logging.getLogger("yolo")
|
7 |
logger.setLevel(logging.DEBUG)
|
8 |
logger.propagate = False
|
9 |
if not logger.hasHandlers():
|
yolo/utils/logging_utils.py
CHANGED
@@ -154,9 +154,6 @@ class YOLORichProgressBar(RichProgressBar):
|
|
154 |
|
155 |
|
156 |
class YOLORichModelSummary(RichModelSummary):
|
157 |
-
|
158 |
-
from typing_extensions import override
|
159 |
-
|
160 |
@staticmethod
|
161 |
@override
|
162 |
def summarize(
|
|
|
154 |
|
155 |
|
156 |
class YOLORichModelSummary(RichModelSummary):
|
|
|
|
|
|
|
157 |
@staticmethod
|
158 |
@override
|
159 |
def summarize(
|