✨ [Finish] Dataloader and get_dataloader
Browse files
- config/config.yaml +1 -0
- config/hyper/default.yaml +5 -0
- utils/dataloader.py +45 -11
- utils/drawer.py +21 -12
config/config.yaml
CHANGED
@@ -7,4 +7,5 @@ defaults:
|
|
7 |
- download: ../data/download
|
8 |
- augmentation: ../data/augmentation
|
9 |
- model: v7-base
|
|
|
10 |
- _self_
|
|
|
7 |
- download: ../data/download
|
8 |
- augmentation: ../data/augmentation
|
9 |
- model: v7-base
|
10 |
+
- hyper: default
|
11 |
- _self_
|
config/hyper/default.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data:
|
2 |
+
batch_size: 4
|
3 |
+
shuffle: True
|
4 |
+
num_workers: 4
|
5 |
+
pin_memory: True
|
utils/dataloader.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from os import listdir, path
|
2 |
-
from typing import Union
|
3 |
|
4 |
import diskcache as dc
|
5 |
import hydra
|
@@ -7,16 +7,11 @@ import numpy as np
|
|
7 |
import torch
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
10 |
-
from torch.utils.data import Dataset
|
|
|
11 |
from tqdm.rich import tqdm
|
12 |
|
13 |
-
from utils.data_augment import
|
14 |
-
Compose,
|
15 |
-
MixUp,
|
16 |
-
Mosaic,
|
17 |
-
RandomHorizontalFlip,
|
18 |
-
RandomVerticalFlip,
|
19 |
-
)
|
20 |
from utils.drawer import draw_bboxes
|
21 |
|
22 |
|
@@ -130,16 +125,55 @@ class YoloDataset(Dataset):
|
|
130 |
img, bboxes = self.get_data(idx)
|
131 |
if self.transform:
|
132 |
img, bboxes = self.transform(img, bboxes)
|
|
|
133 |
return img, bboxes
|
134 |
|
135 |
def __len__(self) -> int:
|
136 |
return len(self.data)
|
137 |
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
140 |
def main(cfg):
|
141 |
-
|
142 |
-
draw_bboxes(
|
143 |
|
144 |
|
145 |
if __name__ == "__main__":
|
|
|
1 |
from os import listdir, path
|
2 |
+
from typing import List, Tuple, Union
|
3 |
|
4 |
import diskcache as dc
|
5 |
import hydra
|
|
|
7 |
import torch
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
10 |
+
from torch.utils.data import DataLoader, Dataset
|
11 |
+
from torchvision.transforms import functional as TF
|
12 |
from tqdm.rich import tqdm
|
13 |
|
14 |
+
from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
from utils.drawer import draw_bboxes
|
16 |
|
17 |
|
|
|
125 |
img, bboxes = self.get_data(idx)
|
126 |
if self.transform:
|
127 |
img, bboxes = self.transform(img, bboxes)
|
128 |
+
img = TF.to_tensor(img)
|
129 |
return img, bboxes
|
130 |
|
131 |
def __len__(self) -> int:
|
132 |
return len(self.data)
|
133 |
|
134 |
|
135 |
+
class YoloDataLoader(DataLoader):
    """DataLoader that builds a ``YoloDataset`` from a config and batches its samples.

    Images are stacked into a single tensor, while the bounding boxes of each
    image are kept as separate entries of a list (see ``collate_fn``).
    """

    def __init__(self, config: dict):
        """Initializes the YoloDataLoader with hydra-config files.

        Args:
            config: Configuration object that is read by attribute access. It is
                forwarded unchanged to ``YoloDataset`` and must expose
                ``hyper.data`` with the fields ``batch_size``, ``shuffle``,
                ``num_workers`` and ``pin_memory``.
        """
        hyper = config.hyper.data
        dataset = YoloDataset(config)

        super().__init__(
            dataset,
            batch_size=hyper.batch_size,
            shuffle=hyper.shuffle,
            num_workers=hyper.num_workers,
            pin_memory=hyper.pin_memory,
            collate_fn=self.collate_fn,
        )

    # A static method: the collate logic needs no loader state, so the callable
    # handed to DataLoader (and to its worker processes) is a plain function
    # instead of a bound method that drags the loader object along with it.
    @staticmethod
    def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        A collate function to handle batching of images and their corresponding targets.

        Args:
            batch (list of tuples): Each tuple contains:
                - image (torch.Tensor): The image tensor.
                - labels (torch.Tensor): The tensor of labels for the image.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
                - A tensor of batched images.
                - A list of tensors, each corresponding to bboxes for each image in the batch.

        Raises:
            RuntimeError: Raised by ``torch.stack`` when the batch is empty or the
                images do not all share the same shape.
        """
        # torch.stack adds a new leading batch dimension; every image must have the same shape.
        images = torch.stack([item[0] for item in batch])
        # Targets stay in a list, one entry per image, rather than being stacked.
        targets = [item[1] for item in batch]
        return images, targets
|
167 |
+
|
168 |
+
|
169 |
+
def get_dataloader(config):
    """Create the ``YoloDataLoader`` described by ``config``.

    Args:
        config: Configuration object, passed unchanged to ``YoloDataLoader``.

    Returns:
        The newly constructed ``YoloDataLoader``.
    """
    loader = YoloDataLoader(config)
    return loader
|
171 |
+
|
172 |
+
|
173 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
    """Build the dataloader from the hydra config and visualize its first batch.

    Args:
        cfg: Configuration object supplied by ``hydra.main``; it is passed
            unchanged to ``get_dataloader``.
    """
    dataloader = get_dataloader(cfg)
    # A batch is an (images, targets) pair, while draw_bboxes expects the image
    # and the boxes as two separate arguments, so unpack before calling it.
    images, targets = next(iter(dataloader))
    draw_bboxes(images, targets)
|
177 |
|
178 |
|
179 |
if __name__ == "__main__":
|
utils/drawer.py
CHANGED
@@ -1,23 +1,31 @@
|
|
|
|
|
|
|
|
|
|
1 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
2 |
|
3 |
|
4 |
-
def draw_bboxes(img, bboxes):
|
5 |
"""
|
6 |
Draw bounding boxes on an image.
|
7 |
|
8 |
Args:
|
9 |
-
-
|
10 |
-
- bboxes (
|
|
|
11 |
"""
|
12 |
-
#
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
try:
|
17 |
-
font = ImageFont.truetype("arial.ttf", 30)
|
18 |
-
except IOError:
|
19 |
-
font = ImageFont.load_default(30)
|
20 |
width, height = img.size
|
|
|
21 |
|
22 |
for bbox in bboxes:
|
23 |
class_id, x_min, y_min, x_max, y_max = bbox
|
@@ -26,7 +34,8 @@ def draw_bboxes(img, bboxes):
|
|
26 |
y_min = y_min * height
|
27 |
y_max = y_max * height
|
28 |
shape = [(x_min, y_min), (x_max, y_max)]
|
29 |
-
draw.rectangle(shape, outline="red", width=
|
30 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
31 |
|
32 |
-
img.save("
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from loguru import logger
|
5 |
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
from torchvision.transforms.functional import to_pil_image
|
7 |
|
8 |
|
9 |
+
def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
|
10 |
"""
|
11 |
Draw bounding boxes on an image.
|
12 |
|
13 |
Args:
|
14 |
+
- img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
|
15 |
+
- bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
|
16 |
+
where coordinates are normalized [0, 1].
|
17 |
"""
|
18 |
+
# Convert tensor image to PIL Image if necessary
|
19 |
+
if isinstance(img, torch.Tensor):
|
20 |
+
if img.dim() > 3:
|
21 |
+
logger.info("Multi-frame tensor detected, using the first image.")
|
22 |
+
img = img[0]
|
23 |
+
bboxes = bboxes[0]
|
24 |
+
img = to_pil_image(img)
|
25 |
|
26 |
+
draw = ImageDraw.Draw(img)
|
|
|
|
|
|
|
|
|
27 |
width, height = img.size
|
28 |
+
font = ImageFont.load_default(30)
|
29 |
|
30 |
for bbox in bboxes:
|
31 |
class_id, x_min, y_min, x_max, y_max = bbox
|
|
|
34 |
y_min = y_min * height
|
35 |
y_max = y_max * height
|
36 |
shape = [(x_min, y_min), (x_max, y_max)]
|
37 |
+
draw.rectangle(shape, outline="red", width=3)
|
38 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
39 |
|
40 |
+
img.save("visualize.jpg") # Save the image with annotations
|
41 |
+
logger.info("Saved visualize image at visualize.png")
|