henry000 commited on
Commit
7854350
·
1 Parent(s): e94b3ff

🔥 [Remove] utils and config again, move to yolo/

Browse files
config/config.py DELETED
@@ -1,72 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict, List, Union
3
-
4
-
5
- @dataclass
6
- class Model:
7
- anchor: List[List[int]]
8
- model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
9
-
10
-
11
- @dataclass
12
- class Download:
13
- auto: bool
14
- path: str
15
-
16
-
17
- @dataclass
18
- class DataLoaderConfig:
19
- batch_size: int
20
- shuffle: bool
21
- num_workers: int
22
- pin_memory: bool
23
-
24
-
25
- @dataclass
26
- class OptimizerArgs:
27
- lr: float
28
- weight_decay: float
29
-
30
-
31
- @dataclass
32
- class OptimizerConfig:
33
- type: str
34
- args: OptimizerArgs
35
-
36
-
37
- @dataclass
38
- class SchedulerArgs:
39
- step_size: int
40
- gamma: float
41
-
42
-
43
- @dataclass
44
- class SchedulerConfig:
45
- type: str
46
- args: SchedulerArgs
47
-
48
-
49
- @dataclass
50
- class EMAConfig:
51
- enabled: bool
52
- decay: float
53
-
54
-
55
- @dataclass
56
- class TrainConfig:
57
- optimizer: OptimizerConfig
58
- scheduler: SchedulerConfig
59
- ema: EMAConfig
60
-
61
-
62
- @dataclass
63
- class HyperConfig:
64
- data: DataLoaderConfig
65
- train: TrainConfig
66
-
67
-
68
- @dataclass
69
- class Config:
70
- model: Model
71
- download: Download
72
- hyper: HyperConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/config.yaml DELETED
@@ -1,11 +0,0 @@
1
- hydra:
2
- run:
3
- dir: ./runs
4
-
5
- defaults:
6
- - data: coco
7
- - download: ../data/download
8
- - augmentation: ../data/augmentation
9
- - model: v7-base
10
- - hyper: default
11
- - _self_
 
 
 
 
 
 
 
 
 
 
 
 
config/data/augmentation.yaml DELETED
@@ -1,3 +0,0 @@
1
- Mosaic: 1
2
- # MixUp: 1
3
- HorizontalFlip: 0.5
 
 
 
 
config/data/download.yaml DELETED
@@ -1,17 +0,0 @@
1
- auto: True
2
- path: data/coco
3
- images:
4
- base_url: http://images.cocodataset.org/zips/
5
- datasets:
6
- train:
7
- file_name: train2017.zip
8
- file_num: 118287
9
- val:
10
- file_name: val2017.zip
11
- num_files: 5000
12
- test:
13
- file_name: test2017.zip
14
- num_files: 40670
15
- hydra:
16
- run:
17
- dir: ./runs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config/hyper/default.yaml DELETED
@@ -1,19 +0,0 @@
1
- data:
2
- batch_size: 4
3
- shuffle: True
4
- num_workers: 4
5
- pin_memory: True
6
- train:
7
- optimizer:
8
- type: Adam
9
- args:
10
- lr: 0.001
11
- weight_decay: 0.0001
12
- scheduler:
13
- type: StepLR
14
- args:
15
- step_size: 10
16
- gamma: 0.1
17
- ema:
18
- enabled: true
19
- decay: 0.995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/data_augment.py DELETED
@@ -1,125 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from PIL import Image
4
- from torchvision.transforms import functional as TF
5
-
6
-
7
- class Compose:
8
- """Composes several transforms together."""
9
-
10
- def __init__(self, transforms, image_size: int = 640):
11
- self.transforms = transforms
12
- self.image_size = image_size
13
-
14
- for transform in self.transforms:
15
- if hasattr(transform, "set_parent"):
16
- transform.set_parent(self)
17
-
18
- def __call__(self, image, boxes):
19
- for transform in self.transforms:
20
- image, boxes = transform(image, boxes)
21
- return image, boxes
22
-
23
-
24
- class HorizontalFlip:
25
- """Randomly horizontally flips the image along with the bounding boxes."""
26
-
27
- def __init__(self, prob=0.5):
28
- self.prob = prob
29
-
30
- def __call__(self, image, boxes):
31
- if torch.rand(1) < self.prob:
32
- image = TF.hflip(image)
33
- boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
34
- return image, boxes
35
-
36
-
37
- class VerticalFlip:
38
- """Randomly vertically flips the image along with the bounding boxes."""
39
-
40
- def __init__(self, prob=0.5):
41
- self.prob = prob
42
-
43
- def __call__(self, image, boxes):
44
- if torch.rand(1) < self.prob:
45
- image = TF.vflip(image)
46
- boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
47
- return image, boxes
48
-
49
-
50
- class Mosaic:
51
- """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
52
-
53
- def __init__(self, prob=0.5):
54
- self.prob = prob
55
- self.parent = None
56
-
57
- def set_parent(self, parent):
58
- self.parent = parent
59
-
60
- def __call__(self, image, boxes):
61
- if torch.rand(1) >= self.prob:
62
- return image, boxes
63
-
64
- assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
65
-
66
- img_sz = self.parent.image_size # Assuming `image_size` is defined in parent
67
- more_data = self.parent.get_more_data(3) # get 3 more images randomly
68
-
69
- data = [(image, boxes)] + more_data
70
- mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
71
- vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
72
- center = np.array([img_sz, img_sz])
73
- all_labels = []
74
-
75
- for (image, boxes), vector in zip(data, vectors):
76
- this_w, this_h = image.size
77
- coord = tuple(center + vector * np.array([this_w, this_h]))
78
-
79
- mosaic_image.paste(image, coord)
80
- xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
81
- xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
82
- xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
83
- ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
84
- ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
85
-
86
- adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
87
- all_labels.append(adjusted_boxes)
88
-
89
- all_labels = torch.cat(all_labels, dim=0)
90
- mosaic_image = mosaic_image.resize((img_sz, img_sz))
91
- return mosaic_image, all_labels
92
-
93
-
94
- class MixUp:
95
- """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
96
-
97
- def __init__(self, prob=0.5, alpha=1.0):
98
- self.alpha = alpha
99
- self.prob = prob
100
- self.parent = None
101
-
102
- def set_parent(self, parent):
103
- """Set the parent dataset object for accessing dataset methods."""
104
- self.parent = parent
105
-
106
- def __call__(self, image, boxes):
107
- if torch.rand(1) >= self.prob:
108
- return image, boxes
109
-
110
- assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
111
-
112
- # Retrieve another image and its boxes randomly from the dataset
113
- image2, boxes2 = self.parent.get_more_data()[0]
114
-
115
- # Calculate the mixup lambda parameter
116
- lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
117
-
118
- # Mix images
119
- image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
120
- mixed_image = lam * image1 + (1 - lam) * image2
121
-
122
- # Mix bounding boxes
123
- mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
124
-
125
- return TF.to_pil_image(mixed_image), mixed_boxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/dataloader.py DELETED
@@ -1,186 +0,0 @@
1
- from os import listdir, path
2
- from typing import List, Tuple, Union
3
-
4
- import diskcache as dc
5
- import hydra
6
- import numpy as np
7
- import torch
8
- from loguru import logger
9
- from PIL import Image
10
- from torch.utils.data import DataLoader, Dataset
11
- from torchvision.transforms import functional as TF
12
- from tqdm.rich import tqdm
13
-
14
- from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
15
- from utils.drawer import draw_bboxes
16
-
17
-
18
- class YoloDataset(Dataset):
19
- def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
20
- dataset_cfg = config.data
21
- augment_cfg = config.augmentation
22
- phase_name = dataset_cfg.get(phase, phase)
23
- self.image_size = image_size
24
-
25
- transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
26
- self.transform = Compose(transforms, self.image_size)
27
- self.transform.get_more_data = self.get_more_data
28
- self.data = self.load_data(dataset_cfg.path, phase_name)
29
-
30
- def load_data(self, dataset_path, phase_name):
31
- """
32
- Loads data from a cache or generates a new cache for a specific dataset phase.
33
-
34
- Parameters:
35
- dataset_path (str): The root path to the dataset directory.
36
- phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
37
-
38
- Returns:
39
- dict: The loaded data from the cache for the specified phase.
40
- """
41
- cache_path = path.join(dataset_path, ".cache")
42
- cache = dc.Cache(cache_path)
43
- data = cache.get(phase_name)
44
-
45
- if data is None:
46
- logger.info("Generating {} cache", phase_name)
47
- images_path = path.join(dataset_path, phase_name, "images")
48
- labels_path = path.join(dataset_path, phase_name, "labels")
49
- data = self.filter_data(images_path, labels_path)
50
- cache[phase_name] = data
51
-
52
- cache.close()
53
- logger.info("📦 Loaded {} cache", phase_name)
54
- data = cache[phase_name]
55
- return data
56
-
57
- def filter_data(self, images_path: str, labels_path: str) -> list:
58
- """
59
- Filters and collects dataset information by pairing images with their corresponding labels.
60
-
61
- Parameters:
62
- images_path (str): Path to the directory containing image files.
63
- labels_path (str): Path to the directory containing label files.
64
-
65
- Returns:
66
- list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
67
- """
68
- data = []
69
- valid_inputs = 0
70
- images_list = sorted(listdir(images_path))
71
- for image_name in tqdm(images_list, desc="Filtering data"):
72
- if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
73
- continue
74
-
75
- img_path = path.join(images_path, image_name)
76
- base_name, _ = path.splitext(image_name)
77
- label_path = path.join(labels_path, f"{base_name}.txt")
78
-
79
- if path.isfile(label_path):
80
- labels = self.load_valid_labels(label_path)
81
- if labels is not None:
82
- data.append((img_path, labels))
83
- valid_inputs += 1
84
-
85
- logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
86
- return data
87
-
88
- def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
89
- """
90
- Loads and validates bounding box data is [0, 1] from a label file.
91
-
92
- Parameters:
93
- label_path (str): The filepath to the label file containing bounding box data.
94
-
95
- Returns:
96
- torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
97
- """
98
- bboxes = []
99
- with open(label_path, "r") as file:
100
- for line in file:
101
- parts = list(map(float, line.strip().split()))
102
- cls = parts[0]
103
- points = np.array(parts[1:]).reshape(-1, 2)
104
- valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
105
- if valid_points.size > 1:
106
- bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
107
- bboxes.append(bbox)
108
-
109
- if bboxes:
110
- return torch.stack(bboxes)
111
- else:
112
- logger.warning("No valid BBox in {}", label_path)
113
- return None
114
-
115
- def get_data(self, idx):
116
- img_path, bboxes = self.data[idx]
117
- img = Image.open(img_path).convert("RGB")
118
- return img, bboxes
119
-
120
- def get_more_data(self, num: int = 1):
121
- indices = torch.randint(0, len(self), (num,))
122
- return [self.get_data(idx) for idx in indices]
123
-
124
- def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
125
- img, bboxes = self.get_data(idx)
126
- if self.transform:
127
- img, bboxes = self.transform(img, bboxes)
128
- img = TF.to_tensor(img)
129
- return img, bboxes
130
-
131
- def __len__(self) -> int:
132
- return len(self.data)
133
-
134
-
135
- class YoloDataLoader(DataLoader):
136
- def __init__(self, config: dict):
137
- """Initializes the YoloDataLoader with hydra-config files."""
138
- hyper = config.hyper.data
139
- dataset = YoloDataset(config)
140
-
141
- super().__init__(
142
- dataset,
143
- batch_size=hyper.batch_size,
144
- shuffle=hyper.shuffle,
145
- num_workers=hyper.num_workers,
146
- pin_memory=hyper.pin_memory,
147
- collate_fn=self.collate_fn,
148
- )
149
-
150
- def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
151
- """
152
- A collate function to handle batching of images and their corresponding targets.
153
-
154
- Args:
155
- batch (list of tuples): Each tuple contains:
156
- - image (torch.Tensor): The image tensor.
157
- - labels (torch.Tensor): The tensor of labels for the image.
158
-
159
- Returns:
160
- Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
161
- - A tensor of batched images.
162
- - A list of tensors, each corresponding to bboxes for each image in the batch.
163
- """
164
- images = torch.stack([item[0] for item in batch])
165
- targets = [item[1] for item in batch]
166
- return images, targets
167
-
168
-
169
- def get_dataloader(config):
170
- return YoloDataLoader(config)
171
-
172
-
173
- @hydra.main(config_path="../config", config_name="config", version_base=None)
174
- def main(cfg):
175
- dataloader = get_dataloader(cfg)
176
- draw_bboxes(next(iter(dataloader)))
177
-
178
-
179
- if __name__ == "__main__":
180
- import sys
181
-
182
- sys.path.append("./")
183
- from tools.log_helper import custom_logger
184
-
185
- custom_logger()
186
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/drawer.py DELETED
@@ -1,41 +0,0 @@
1
- from typing import List, Union
2
-
3
- import torch
4
- from loguru import logger
5
- from PIL import Image, ImageDraw, ImageFont
6
- from torchvision.transforms.functional import to_pil_image
7
-
8
-
9
- def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
10
- """
11
- Draw bounding boxes on an image.
12
-
13
- Args:
14
- - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
15
- - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
16
- where coordinates are normalized [0, 1].
17
- """
18
- # Convert tensor image to PIL Image if necessary
19
- if isinstance(img, torch.Tensor):
20
- if img.dim() > 3:
21
- logger.info("Multi-frame tensor detected, using the first image.")
22
- img = img[0]
23
- bboxes = bboxes[0]
24
- img = to_pil_image(img)
25
-
26
- draw = ImageDraw.Draw(img)
27
- width, height = img.size
28
- font = ImageFont.load_default(30)
29
-
30
- for bbox in bboxes:
31
- class_id, x_min, y_min, x_max, y_max = bbox
32
- x_min = x_min * width
33
- x_max = x_max * width
34
- y_min = y_min * height
35
- y_max = y_max * height
36
- shape = [(x_min, y_min), (x_max, y_max)]
37
- draw.rectangle(shape, outline="red", width=3)
38
- draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
39
-
40
- img.save("visualize.jpg") # Save the image with annotations
41
- logger.info("Saved visualize image at visualize.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/get_dataset.py DELETED
@@ -1,84 +0,0 @@
1
- import os
2
- import zipfile
3
-
4
- import hydra
5
- import requests
6
- from loguru import logger
7
- from tqdm.rich import tqdm
8
-
9
-
10
- def download_file(url, dest_path):
11
- """
12
- Downloads a file from a specified URL to a destination path with progress logging.
13
- """
14
- logger.info(f"Downloading {os.path.basename(dest_path)}...")
15
- with requests.get(url, stream=True) as r:
16
- r.raise_for_status()
17
- total_length = int(r.headers.get("content-length", 0))
18
- with open(dest_path, "wb") as f, tqdm(
19
- total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
20
- ) as bar:
21
- for chunk in r.iter_content(chunk_size=1024 * 1024):
22
- f.write(chunk)
23
- bar.update(len(chunk))
24
- logger.info("Download complete!")
25
-
26
-
27
- def unzip_file(zip_path, extract_to):
28
- """
29
- Unzips a ZIP file to a specified directory.
30
- """
31
- logger.info(f"Unzipping {os.path.basename(zip_path)}...")
32
- with zipfile.ZipFile(zip_path, "r") as zip_ref:
33
- zip_ref.extractall(extract_to)
34
- os.remove(zip_path)
35
- logger.info(f"Removed {zip_path}")
36
-
37
-
38
- def check_files(directory, expected_count):
39
- """
40
- Checks if the specified directory has the expected number of files.
41
- """
42
- num_files = len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
43
- return num_files == expected_count
44
-
45
-
46
- @hydra.main(config_path="../config/data", config_name="download", version_base=None)
47
- def prepare_dataset(download_cfg):
48
- data_dir = download_cfg.path
49
- base_url = download_cfg.images.base_url
50
- datasets = download_cfg.images.datasets
51
-
52
- for dataset_type in datasets:
53
- file_name, expected_files = datasets[dataset_type].values()
54
- url = f"{base_url}{file_name}"
55
- local_zip_path = os.path.join(data_dir, file_name)
56
- extract_to = os.path.join(data_dir, dataset_type, "images")
57
-
58
- # Ensure the extraction directory exists
59
- os.makedirs(extract_to, exist_ok=True)
60
-
61
- # Check if the correct number of files exists
62
- if check_files(extract_to, expected_files):
63
- logger.info(f"✅ Dataset {dataset_type: >4} already verified.")
64
- continue
65
-
66
- if os.path.exists(local_zip_path):
67
- logger.info(f"Dataset {dataset_type} already downloaded.")
68
- else:
69
- download_file(url, local_zip_path)
70
-
71
- unzip_file(local_zip_path, extract_to)
72
-
73
- print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))
74
-
75
- # Additional verification post extraction
76
- if not check_files(extract_to, expected_files):
77
- logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")
78
-
79
-
80
- if __name__ == "__main__":
81
- from tools.log_helper import custom_logger
82
-
83
- custom_logger()
84
- prepare_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/loss.py DELETED
@@ -1,2 +0,0 @@
1
- def get_loss_function(*args, **kwargs):
2
- raise NotImplementedError