from PIL import Image from os import path, listdir import hydra import numpy as np import torch from torch.utils.data import Dataset from loguru import logger from tqdm.rich import tqdm import diskcache as dc from typing import Union from drawer import draw_bboxes from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp class YoloDataset(Dataset): def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None): phase_name = dataset_cfg.get(phase, phase) self.image_size = image_size self.transform = transform self.transform.get_more_data = self.get_more_data self.transform.image_size = self.image_size self.data = self.load_data(dataset_cfg.path, phase_name) def load_data(self, dataset_path, phase_name): """ Loads data from a cache or generates a new cache for a specific dataset phase. Parameters: dataset_path (str): The root path to the dataset directory. phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for. Returns: dict: The loaded data from the cache for the specified phase. """ cache_path = path.join(dataset_path, ".cache") cache = dc.Cache(cache_path) data = cache.get(phase_name) if data is None: logger.info("Generating {} cache", phase_name) images_path = path.join(dataset_path, phase_name, "images") labels_path = path.join(dataset_path, phase_name, "labels") data = self.filter_data(images_path, labels_path) cache[phase_name] = data cache.close() logger.info("Loaded {} cache", phase_name) data = cache[phase_name] return data def filter_data(self, images_path: str, labels_path: str) -> list: """ Filters and collects dataset information by pairing images with their corresponding labels. Parameters: images_path (str): Path to the directory containing image files. labels_path (str): Path to the directory containing label files. Returns: list: A list of tuples, each containing the path to an image file and its associated labels as a tensor. """ data = [] valid_inputs = 0 images_list = sorted(listdir(images_path)) for image_name in tqdm(images_list, desc="Filtering data"): if not image_name.lower().endswith((".jpg", ".jpeg", ".png")): continue img_path = path.join(images_path, image_name) base_name, _ = path.splitext(image_name) label_path = path.join(labels_path, f"{base_name}.txt") if path.isfile(label_path): labels = self.load_valid_labels(label_path) if labels is not None: data.append((img_path, labels)) valid_inputs += 1 logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list)) return data def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]: """ Loads and validates bounding box data is [0, 1] from a label file. Parameters: label_path (str): The filepath to the label file containing bounding box data. Returns: torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None. """ bboxes = [] with open(label_path, "r") as file: for line in file: parts = list(map(float, line.strip().split())) cls = parts[0] points = np.array(parts[1:]).reshape(-1, 2) valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) if valid_points.size > 1: bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) bboxes.append(bbox) if bboxes: return torch.stack(bboxes) else: logger.warning("No valid BBox in {}", label_path) return None def get_data(self, idx): img_path, bboxes = self.data[idx] img = Image.open(img_path).convert("RGB") return img, bboxes def get_more_data(self, num: int = 1): indices = torch.randint(0, len(self), (num,)) return [self.get_data(idx) for idx in indices] def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]: img, bboxes = self.get_data(idx) if self.transform: img, bboxes = self.transform(img, bboxes) return img, bboxes def __len__(self) -> int: return len(self.data) @hydra.main(config_path="../config", config_name="config", version_base=None) def main(cfg): transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()]) dataset = YoloDataset(cfg.data, transform=transform) draw_bboxes(*dataset[0]) if __name__ == "__main__": import sys sys.path.append("./") from tools.log_helper import custom_logger custom_logger() main()