from PIL import Image
from os import path, listdir

import hydra
import numpy as np
import torch
from torch.utils.data import Dataset
from loguru import logger
from tqdm.rich import tqdm
import diskcache as dc
from typing import Union
from drawer import draw_bboxes
from dataargument import Compose, RandomHorizontalFlip  # RandomHorizontalFlip is looked up via eval() in main()


class YoloDataset(Dataset):
    def __init__(self, dataset_cfg: dict, phase: str = "train", transform=None):
        phase_name = dataset_cfg.get(phase, phase)
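        # The config may remap a phase to a custom folder name
        # (e.g. dataset_cfg = {"train": "train2017"}); otherwise the phase
        # string itself is used as the directory name.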

        self.transform = transform
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (str): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            dict: The loaded data from the cache for the specified phase.
        """
        cache_path = path.join(dataset_path, ".cache")
        cache = dc.Cache(cache_path)
        data = cache.get(phase_name)

        if data is None:
            logger.info("Generating {} cache", phase_name)
            images_path = path.join(dataset_path, phase_name, "images")
            labels_path = path.join(dataset_path, phase_name, "labels")
            data = self.filter_data(images_path, labels_path)
            cache[phase_name] = data

        cache.close()
        logger.info("Loaded {} cache", phase_name)
        return data

    def filter_data(self, images_path: str, labels_path: str) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Parameters:
            images_path (str): Path to the directory containing image files.
            labels_path (str): Path to the directory containing label files.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
        """
        data = []
        valid_inputs = 0
        images_list = sorted(listdir(images_path))
        for image_name in tqdm(images_list, desc="Filtering data"):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue

            img_path = path.join(images_path, image_name)
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, f"{base_name}.txt")

            if path.isfile(label_path):
                labels = self.load_valid_labels(label_path)
                if labels is not None:
                    data.append((img_path, labels))
                    valid_inputs += 1

        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
        return data

    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
        """
        Loads bounding box data from a label file, keeping only points whose coordinates fall within [0, 1].

        Parameters:
            label_path (str): The filepath to the label file containing bounding box data.

        Returns:
            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line in file:
                parts = list(map(float, line.strip().split()))
                cls = parts[0]
                points = np.array(parts[1:]).reshape(-1, 2)
                # Keep only points whose x and y both lie inside the normalized [0, 1] range.
                valid_points = points[np.all((points >= 0) & (points <= 1), axis=1)]
                if valid_points.size > 1:
                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                    bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        else:
            logger.warning("No valid BBox in {}", label_path)
            return None

    def __getitem__(self, idx) -> tuple:
        img_path, bboxes = self.data[idx]
        img = Image.open(img_path).convert("RGB")
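        # The transform is expected to take and return (image, boxes) jointly,
        # since geometric augmentations (e.g. a horizontal flip) must update
        # the box coordinates along with the image.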
        if self.transform:
            img, bboxes = self.transform(img, bboxes)
        return img, bboxes

    def __len__(self) -> int:
        return len(self.data)


@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
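    # cfg.augmentation is assumed to map augmentation class names to their
    # probabilities, e.g. in the Hydra config:
    #   augmentation:
    #     RandomHorizontalFlip: 0.5
    # Each name is resolved via eval() against the classes imported above.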
    transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
    dataset = YoloDataset(cfg.data, transform=transform)
    draw_bboxes(*dataset[0])


if __name__ == "__main__":
    import sys

    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    main()