♻️ [Refactor] move to_tensor to transform
Browse files- yolo/__init__.py +23 -0
- yolo/tools/data_augmentation.py +1 -0
- yolo/tools/data_loader.py +2 -4
yolo/__init__.py
CHANGED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yolo.config.config import Config
|
2 |
+
from yolo.model.yolo import create_model
|
3 |
+
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
|
4 |
+
from yolo.tools.drawer import draw_bboxes
|
5 |
+
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
6 |
+
from yolo.utils.bounding_box_utils import bbox_nms
|
7 |
+
from yolo.utils.deploy_utils import FastModelLoader
|
8 |
+
from yolo.utils.logging_utils import custom_logger
|
9 |
+
|
10 |
+
all = [
|
11 |
+
"create_model",
|
12 |
+
"Config",
|
13 |
+
"custom_logger",
|
14 |
+
"validate_log_directory",
|
15 |
+
"draw_bboxes",
|
16 |
+
"bbox_nms",
|
17 |
+
"AugmentationComposer",
|
18 |
+
"create_dataloader",
|
19 |
+
"FastModelLoader",
|
20 |
+
"ModelTester",
|
21 |
+
"ModelTrainer",
|
22 |
+
"ModelValidator",
|
23 |
+
]
|
yolo/tools/data_augmentation.py
CHANGED
@@ -21,6 +21,7 @@ class AugmentationComposer:
|
|
21 |
for transform in self.transforms:
|
22 |
image, boxes = transform(image, boxes)
|
23 |
image, boxes = self.pad_resize(image, boxes)
|
|
|
24 |
return image, boxes
|
25 |
|
26 |
|
|
|
21 |
for transform in self.transforms:
|
22 |
image, boxes = transform(image, boxes)
|
23 |
image, boxes = self.pad_resize(image, boxes)
|
24 |
+
image = TF.to_tensor(image)
|
25 |
return image, boxes
|
26 |
|
27 |
|
yolo/tools/data_loader.py
CHANGED
@@ -151,9 +151,7 @@ class YoloDataset(Dataset):
|
|
151 |
|
152 |
def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
|
153 |
img, bboxes = self.get_data(idx)
|
154 |
-
|
155 |
-
img, bboxes = self.transform(img, bboxes)
|
156 |
-
img = TF.to_tensor(img)
|
157 |
return img, bboxes
|
158 |
|
159 |
def __len__(self) -> int:
|
@@ -269,7 +267,7 @@ class StreamDataLoader:
|
|
269 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
270 |
frame = Image.fromarray(frame)
|
271 |
frame, _ = self.transform(frame, torch.zeros(0, 5))
|
272 |
-
frame =
|
273 |
if not self.is_stream:
|
274 |
self.queue.put(frame)
|
275 |
else:
|
|
|
151 |
|
152 |
def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
|
153 |
img, bboxes = self.get_data(idx)
|
154 |
+
img, bboxes = self.transform(img, bboxes)
|
|
|
|
|
155 |
return img, bboxes
|
156 |
|
157 |
def __len__(self) -> int:
|
|
|
267 |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
268 |
frame = Image.fromarray(frame)
|
269 |
frame, _ = self.transform(frame, torch.zeros(0, 5))
|
270 |
+
frame = frame[None]
|
271 |
if not self.is_stream:
|
272 |
self.queue.put(frame)
|
273 |
else:
|