henry000 committed
Commit b11b504 Β· Parents: 2218dc8 2007b83

πŸ”€ [Merge] branch 'DATASET' into TEST

.gitignore CHANGED
@@ -111,8 +111,8 @@ dmypy.json
 
 # Machine learning specific folders and symlinks
 runs
-data
-datasets
+/data
+/datasets
 
 # Datasets and model checkpoints
 *.pth
README.md CHANGED
@@ -3,7 +3,7 @@ An MIT license rewrite of YOLOv9
 
 ## To-Do Lists
 - [ ] Project Setup
-- [ ] requirements
+- [X] requirements
 - [ ] LICENSE
 - [ ] README
 - [ ] pytests
config/config.py ADDED
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from typing import List, Dict, Union
+
+
+@dataclass
+class Model:
+    anchor: List[List[int]]
+    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
+
+
+@dataclass
+class Download:
+    auto: bool
+    path: str
+
+
+@dataclass
+class Config:
+    model: Model
+    download: Download
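
These dataclasses mirror the shape of the composed Hydra config, but nothing in this commit registers them with Hydra, so the cfg: Config annotation in train.py acts as a static type hint rather than runtime validation. If validation were wanted, a schema could be registered through Hydra's ConfigStore; a minimal sketch, not part of this commit:

# Sketch only: schema registration is not done anywhere in this diff.
from hydra.core.config_store import ConfigStore
from config.config import Config

cs = ConfigStore.instance()
cs.store(name="config_schema", node=Config)  # Hydra would then validate the composed config against Config
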
config/config.yaml ADDED
@@ -0,0 +1,10 @@
+hydra:
+  run:
+    dir: ./runs
+
+defaults:
+  - data: coco
+  - download: ../data/download
+  - augmentation: ../data/augmentation
+  - model: v7-base
+  - _self_
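
The defaults list composes each group file into one tree under its group name (the ../data/... entries step out of the default group directory into config/data/). A quick way to inspect the merged result, assuming it is run from the repository root:

# Sketch: dump the fully composed config that train.py receives.
import hydra
from omegaconf import OmegaConf

@hydra.main(config_path="config", config_name="config", version_base=None)
def dump(cfg):
    print(OmegaConf.to_yaml(cfg))  # data, download, augmentation, and model groups merged

if __name__ == "__main__":
    dump()
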
config/data/augmentation.yaml ADDED
@@ -0,0 +1,2 @@
+RandomHorizontalFlip: 0.5
+Mosaic: 0.5
config/data/coco.yaml ADDED
@@ -0,0 +1 @@
+path: data/coco
config/data/download.yaml ADDED
@@ -0,0 +1,17 @@
+auto: True
+path: data/coco
+images:
+  base_url: http://images.cocodataset.org/zips/
+  datasets:
+    train:
+      file_name: train2017.zip
+      num_files: 118287
+    val:
+      file_name: val2017.zip
+      num_files: 5000
+    test:
+      file_name: test2017.zip
+      num_files: 40670
+hydra:
+  run:
+    dir: ./runs
model/yolo.py CHANGED
@@ -1,25 +1,10 @@
-import inspect
 from typing import Any, Dict, List, Union
 
 import torch
 import torch.nn as nn
 from loguru import logger
-
-from model import module
-from utils.tools import load_model_cfg
-
-
-def get_layer_map():
-    """
-    Dynamically generates a dictionary mapping class names to classes,
-    filtering to include only those that are subclasses of nn.Module,
-    ensuring they are relevant neural network layers.
-    """
-    layer_map = {}
-    for name, obj in inspect.getmembers(module, inspect.isclass):
-        if issubclass(obj, nn.Module) and obj is not nn.Module:
-            layer_map[name] = obj
-    return layer_map
+from omegaconf import OmegaConf
+from tools.layer_helper import get_layer_map
 
 
 class YOLO(nn.Module):
@@ -35,16 +20,15 @@ class YOLO(nn.Module):
         super(YOLO, self).__init__()
         self.nc = model_cfg["nc"]
         self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
-        self.build_model(model_cfg["model"])
+        self.build_model(model_cfg.model)
 
     def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
         model_list = nn.ModuleList()
         output_dim = [3]
         layer_indices_by_tag = {}
-
-        for arch_name, arch in model_arch.items():
+        for arch_name in model_arch:
             logger.info(f"πŸ—οΈ Building model-{arch_name}")
-            for layer_idx, layer_spec in enumerate(arch, start=1):
+            for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=1):
                 layer_type, layer_info = next(iter(layer_spec.items()))
                 layer_args = layer_info.get("args", {})
                 source = layer_info.get("source", -1)
@@ -74,7 +58,7 @@ class YOLO(nn.Module):
         y = [x]
         output = []
         for layer in self.model:
-            if isinstance(layer.source, list):
+            if OmegaConf.is_list(layer.source):
                 model_input = [y[idx] for idx in layer.source]
             else:
                 model_input = y[layer.source]
@@ -113,6 +97,7 @@ def get_model(model_cfg: dict) -> YOLO:
     Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
+    OmegaConf.set_struct(model_cfg, False)
     model = YOLO(model_cfg)
     logger.info("βœ… Success load model")
     return model
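
Two of the quieter changes here are worth a note: OmegaConf.set_struct(model_cfg, False) relaxes struct mode so build_model can attach keys the config never declared, and the isinstance(layer.source, list) check had to become OmegaConf.is_list(...) because OmegaConf wraps YAML sequences in ListConfig, which is not a list subclass. A minimal demonstration of the latter:

# Why isinstance(..., list) no longer works once configs come from OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"source": [-1, 4]})
print(isinstance(cfg.source, list))   # False: cfg.source is a ListConfig
print(OmegaConf.is_list(cfg.source))  # True
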
requirements.txt ADDED
@@ -0,0 +1,9 @@
+hydra-core
+loguru
+numpy
+pytest
+pyyaml
+requests
+rich
+torch
+tqdm
tools/layer_helper.py ADDED
@@ -0,0 +1,20 @@
+import inspect
+import torch.nn as nn
+from model import module
+
+
+def auto_pad():
+    raise NotImplementedError
+
+
+def get_layer_map():
+    """
+    Dynamically generates a dictionary mapping class names to classes,
+    filtering to include only those that are subclasses of nn.Module,
+    ensuring they are relevant neural network layers.
+    """
+    layer_map = {}
+    for name, obj in inspect.getmembers(module, inspect.isclass):
+        if issubclass(obj, nn.Module) and obj is not nn.Module:
+            layer_map[name] = obj
+    return layer_map
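
This map is what lets the YAML model spec refer to layers by class name. A hedged sketch of the lookup; "Conv" and its kwargs are hypothetical stand-ins for whatever nn.Module subclasses model/module.py actually defines:

# Hypothetical sketch: resolving a YAML layer name to a class and instantiating it.
from tools.layer_helper import get_layer_map

layer_map = get_layer_map()                          # e.g. {"Conv": <class>, ...} from model/module.py
layer_args = {"in_channels": 3, "out_channels": 64}  # hypothetical "args" entry from the model YAML
layer = layer_map["Conv"](**layer_args)              # "Conv" assumed to be defined in model/module.py
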
tools/log_helper.py ADDED
@@ -0,0 +1,23 @@
+"""
+Module for initializing logging tools used in machine learning and data processing.
+Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other
+logging frameworks as needed.
+
+This setup ensures consistent logging across various platforms, facilitating
+effective monitoring and debugging.
+
+Example:
+    from tools.log_helper import custom_logger
+    custom_logger()
+"""
+
+import sys
+from loguru import logger
+
+
+def custom_logger():
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
+    )
train.py CHANGED
@@ -1,26 +1,19 @@
-import argparse
 from loguru import logger
 from model.yolo import get_model
-from utils.tools import load_model_cfg, custom_logger
+from tools.log_helper import custom_logger
+from utils.get_dataset import prepare_dataset
+import hydra
+from config.config import Config
 
 
-def parse_arguments() -> argparse.Namespace:
-    """
-    Parse command-line arguments to get the model configuration file.
-
-    Returns:
-        argparse.Namespace: The command-line arguments object with 'config' attribute.
-    """
-    parser = argparse.ArgumentParser(description="Load a YOLO model configuration and display the model.")
-    parser.add_argument(
-        "--model-config", type=str, default="v7-base", help="Name or path to the model configuration file."
-    )
-    return parser.parse_args()
+@hydra.main(config_path="config", config_name="config", version_base=None)
+def main(cfg: Config):
+    if cfg.download.auto:
+        prepare_dataset(cfg.download)
+
+    model = get_model(cfg.model)
 
 
 if __name__ == "__main__":
     custom_logger()
-    args = parse_arguments()
-    model_cfg = load_model_cfg(args.model_config)
-    model = get_model(model_cfg)
-    logger.info("Success load model")
+    main()
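
With argparse gone, runtime options become Hydra dotted overrides on the command line, e.g. python train.py model=v7-base download.auto=False (standard Hydra override syntax, with group and key names as defined in the configs above).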
 
 
 
utils/dataargument.py ADDED
@@ -0,0 +1,79 @@
+from PIL import Image
+import numpy as np
+import torch
+from torchvision.transforms import functional as TF
+
+
+class Compose:
+    """Composes several transforms together."""
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+        for transform in self.transforms:
+            if hasattr(transform, "set_parent"):
+                transform.set_parent(self)
+
+    def __call__(self, image, boxes):
+        for transform in self.transforms:
+            image, boxes = transform(image, boxes)
+        return image, boxes
+
+    def get_more_data(self):
+        raise NotImplementedError("This method should be overridden by subclass instances!")
+
+
+class RandomHorizontalFlip:
+    """Randomly horizontally flips the image along with the bounding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) < self.prob:
+            image = TF.hflip(image)
+            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
+        return image, boxes
+
+
+class Mosaic:
+    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+        self.parent = None
+
+    def set_parent(self, parent):
+        self.parent = parent
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) >= self.prob:
+            return image, boxes
+
+        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
+
+        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
+        more_data = self.parent.get_more_data(3)  # get 3 more images randomly
+
+        data = [(image, boxes)] + more_data
+        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
+        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
+        center = np.array([img_sz, img_sz])
+        all_labels = []
+
+        for (image, boxes), vector in zip(data, vectors):
+            this_w, this_h = image.size
+            coord = tuple(center + vector * np.array([this_w, this_h]))
+
+            mosaic_image.paste(image, coord)
+            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
+            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
+            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
+            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
+            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
+
+            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
+            all_labels.append(adjusted_boxes)
+
+        all_labels = torch.cat(all_labels, dim=0)
+        return mosaic_image, all_labels
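
The flip update boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]] mirrors and swaps the normalized x-coordinates in one step, so the new xmin is 1 - xmax and the new xmax is 1 - xmin, preserving xmin <= xmax. A tiny check with the [cls, xmin, ymin, xmax, ymax] layout used above:

# Minimal check of the horizontal-flip box update.
import torch

boxes = torch.tensor([[0.0, 0.2, 0.3, 0.5, 0.6]])  # [cls, xmin, ymin, xmax, ymax]
boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]            # xmin' = 1 - xmax, xmax' = 1 - xmin
print(boxes)  # tensor([[0.0000, 0.5000, 0.3000, 0.8000, 0.6000]])
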
utils/dataloader.py ADDED
@@ -0,0 +1,144 @@
+from PIL import Image
+from os import path, listdir
+
+import hydra
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from loguru import logger
+from tqdm.rich import tqdm
+import diskcache as dc
+from typing import Union
+from drawer import draw_bboxes
+from dataargument import Compose, RandomHorizontalFlip, Mosaic
+
+
+class YoloDataset(Dataset):
+    def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
+        phase_name = dataset_cfg.get(phase, phase)
+        self.image_size = image_size
+
+        self.transform = transform
+        self.transform.get_more_data = self.get_more_data
+        self.transform.image_size = self.image_size
+        self.data = self.load_data(dataset_cfg.path, phase_name)
+
+    def load_data(self, dataset_path, phase_name):
+        """
+        Loads data from a cache or generates a new cache for a specific dataset phase.
+
+        Parameters:
+            dataset_path (str): The root path to the dataset directory.
+            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
+
+        Returns:
+            dict: The loaded data from the cache for the specified phase.
+        """
+        cache_path = path.join(dataset_path, ".cache")
+        cache = dc.Cache(cache_path)
+        data = cache.get(phase_name)
+
+        if data is None:
+            logger.info("Generating {} cache", phase_name)
+            images_path = path.join(dataset_path, phase_name, "images")
+            labels_path = path.join(dataset_path, phase_name, "labels")
+            data = self.filter_data(images_path, labels_path)
+            cache[phase_name] = data
+
+        cache.close()
+        logger.info("Loaded {} cache", phase_name)
+        data = cache[phase_name]
+        return data
+
+    def filter_data(self, images_path: str, labels_path: str) -> list:
+        """
+        Filters and collects dataset information by pairing images with their corresponding labels.
+
+        Parameters:
+            images_path (str): Path to the directory containing image files.
+            labels_path (str): Path to the directory containing label files.
+
+        Returns:
+            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
+        """
+        data = []
+        valid_inputs = 0
+        images_list = sorted(listdir(images_path))
+        for image_name in tqdm(images_list, desc="Filtering data"):
+            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
+                continue
+
+            img_path = path.join(images_path, image_name)
+            base_name, _ = path.splitext(image_name)
+            label_path = path.join(labels_path, f"{base_name}.txt")
+
+            if path.isfile(label_path):
+                labels = self.load_valid_labels(label_path)
+                if labels is not None:
+                    data.append((img_path, labels))
+                    valid_inputs += 1
+
+        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
+        return data
+
+    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
+        """
+        Loads bounding box data from a label file and validates that coordinates lie in [0, 1].
+
+        Parameters:
+            label_path (str): The filepath to the label file containing bounding box data.
+
+        Returns:
+            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
+        """
+        bboxes = []
+        with open(label_path, "r") as file:
+            for line in file:
+                parts = list(map(float, line.strip().split()))
+                cls = parts[0]
+                points = np.array(parts[1:]).reshape(-1, 2)
+                valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
+                if valid_points.size > 1:
+                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
+                    bboxes.append(bbox)
+
+        if bboxes:
+            return torch.stack(bboxes)
+        else:
+            logger.warning("No valid BBox in {}", label_path)
+            return None
+
+    def get_data(self, idx):
+        img_path, bboxes = self.data[idx]
+        img = Image.open(img_path).convert("RGB")
+        return img, bboxes
+
+    def get_more_data(self, num: int = 1):
+        indices = torch.randint(0, len(self), (num,))
+        return [self.get_data(idx) for idx in indices]
+
+    def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
+        img, bboxes = self.get_data(idx)
+        if self.transform:
+            img, bboxes = self.transform(img, bboxes)
+        return img, bboxes
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+@hydra.main(config_path="../config", config_name="config", version_base=None)
+def main(cfg):
+    transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
+    dataset = YoloDataset(cfg.data, transform=transform)
+    draw_bboxes(*dataset[0])
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
+    from tools.log_helper import custom_logger
+
+    custom_logger()
+    main()
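
For the augmentation.yaml in this commit, the eval()-based comprehension in main() expands to the equivalent of:

# Equivalent of the comprehension for {RandomHorizontalFlip: 0.5, Mosaic: 0.5}.
transform = Compose([RandomHorizontalFlip(0.5), Mosaic(0.5)])
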
utils/drawer.py ADDED
@@ -0,0 +1,32 @@
+from PIL import Image, ImageDraw, ImageFont
+
+
+def draw_bboxes(img, bboxes):
+    """
+    Draw bounding boxes on an image.
+
+    Args:
+    - img (PIL.Image.Image): The image to draw on.
+    - bboxes (list of lists/tuples): Bounding boxes with [class_id, x_min, y_min, x_max, y_max], normalized to [0, 1].
+    """
+    # Prepare a drawing context for the image
+    draw = ImageDraw.Draw(img)
+
+    # Font for class_id (optional)
+    try:
+        font = ImageFont.truetype("arial.ttf", 30)
+    except IOError:
+        font = ImageFont.load_default(30)
+    width, height = img.size
+
+    for bbox in bboxes:
+        class_id, x_min, y_min, x_max, y_max = bbox
+        x_min = x_min * width
+        x_max = x_max * width
+        y_min = y_min * height
+        y_max = y_max * height
+        shape = [(x_min, y_min), (x_max, y_max)]
+        draw.rectangle(shape, outline="red", width=2)
+        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
+
+    img.save("output.jpg")
utils/get_dataset.py ADDED
@@ -0,0 +1,84 @@
+import os
+import zipfile
+
+import hydra
+from loguru import logger
+import requests
+from tqdm.rich import tqdm
+
+
+def download_file(url, dest_path):
+    """
+    Downloads a file from a specified URL to a destination path with progress logging.
+    """
+    logger.info(f"Downloading {os.path.basename(dest_path)}...")
+    with requests.get(url, stream=True) as r:
+        r.raise_for_status()
+        total_length = int(r.headers.get("content-length", 0))
+        with open(dest_path, "wb") as f, tqdm(
+            total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
+        ) as bar:
+            for chunk in r.iter_content(chunk_size=1024 * 1024):
+                f.write(chunk)
+                bar.update(len(chunk))
+    logger.info("Download complete!")
+
+
+def unzip_file(zip_path, extract_to):
+    """
+    Unzips a ZIP file to a specified directory.
+    """
+    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
+        zip_ref.extractall(extract_to)
+    os.remove(zip_path)
+    logger.info(f"Removed {zip_path}")
+
+
+def check_files(directory, expected_count):
+    """
+    Checks if the specified directory has the expected number of files.
+    """
+    num_files = len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
+    return num_files == expected_count
+
+
+@hydra.main(config_path="../config/data", config_name="download", version_base=None)
+def prepare_dataset(download_cfg):
+    data_dir = download_cfg.path
+    base_url = download_cfg.images.base_url
+    datasets = download_cfg.images.datasets
+
+    for dataset_type in datasets:
+        file_name, expected_files = datasets[dataset_type].values()
+        url = f"{base_url}{file_name}"
+        local_zip_path = os.path.join(data_dir, file_name)
+        extract_to = os.path.join(data_dir, dataset_type, "images")
+
+        # Ensure the extraction directory exists
+        os.makedirs(extract_to, exist_ok=True)
+
+        # Check if the correct number of files exists
+        if check_files(extract_to, expected_files):
+            logger.info(f"Dataset {dataset_type} already verified.")
+            continue
+
+        if os.path.exists(local_zip_path):
+            logger.info(f"Dataset {dataset_type} already downloaded.")
+        else:
+            download_file(url, local_zip_path)
+
+        unzip_file(local_zip_path, extract_to)
+
+        print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))
+
+        # Additional verification post extraction
+        if not check_files(extract_to, expected_files):
+            logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")
+
+
+if __name__ == "__main__":
+    from tools.log_helper import custom_logger
+
+    custom_logger()
+    prepare_dataset()
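
If every download and extraction succeeds with the config above (path: data/coco), the layout prepare_dataset leaves behind should be:

data/coco/
β”œβ”€β”€ train/images/   # 118287 files from train2017.zip
β”œβ”€β”€ val/images/     # 5000 files from val2017.zip
└── test/images/    # 40670 files from test2017.zip
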
utils/tools.py DELETED
@@ -1,72 +0,0 @@
-import os
-import sys
-import yaml
-from loguru import logger
-from typing import Dict, Any
-
-
-def complete_path(file_name: str = "v7-base.yaml") -> str:
-    """
-    Ensures the path to a model configuration is an existing file
-
-    Parameters:
-        file_name (str): The filename or path, with default 'v7-base.yaml'.
-
-    Returns:
-        str: A complete path with necessary prefix and extension.
-    """
-    # Ensure the file has the '.yaml' extension if missing
-    if not file_name.endswith(".yaml"):
-        file_name += ".yaml"
-
-    # Add folder prefix if only the filename is provided
-    if os.path.dirname(file_name) == "":
-        file_name = os.path.join("./config/model", file_name)
-
-    return file_name
-
-
-def load_model_cfg(file_path: str) -> Dict[str, Any]:
-    """
-    Read a YAML configuration file, ensure necessary keys are present, and return its content as a dictionary.
-
-    Args:
-        file_path (str): The path to the YAML configuration file.
-
-    Returns:
-        Dict[str, Any]: The contents of the YAML file as a dictionary.
-
-    Raises:
-        FileNotFoundError: If the YAML file cannot be found.
-        yaml.YAMLError: If there is an error parsing the YAML file.
-    """
-    file_path = complete_path(file_path)
-    try:
-        with open(file_path, "r") as file:
-            model_cfg = yaml.safe_load(file) or {}
-
-        # Check for required keys and set defaults if not present
-        if "nc" not in model_cfg:
-            model_cfg["nc"] = 80
-            logger.warning("'nc' not found in the YAML file. Setting default 'nc' to 80.")
-
-        if "model" not in model_cfg:
-            logger.error("'model' is missing in the configuration file.")
-            raise ValueError("Missing required key: 'model'")
-
-        return model_cfg
-
-    except FileNotFoundError:
-        logger.error(f"YAML file not found: {file_path}")
-        raise
-    except yaml.YAMLError as e:
-        logger.error(f"Error parsing YAML file: {e}")
-        raise
-
-
-def custom_logger():
-    logger.remove()
-    logger.add(
-        sys.stderr,
-        format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
-    )