henry000 committed
Commit b44d6bb · Parents: d24904a 1197f7d

🔀 [Merge] branch 'SETUP' into MODEL

LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Kin-Yiu, Wong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -20,31 +20,31 @@ If you are interested in contributing, please keep an eye on project updates or
 ## To-Do Lists
 - [ ] Project Setup
     - [X] requirements
-    - [ ] LICENSE
+    - [x] LICENSE
     - [ ] README
-    - [ ] pytests
+    - [x] pytests
     - [ ] setup.py/pip install
-    - [ ] log format
+    - [x] log format
     - [ ] hugging face
 - [ ] Data process
     - [ ] Dataset
-        - [ ] Download script
+        - [x] Download script
         - [ ] Auto Download
         - [ ] xywh, xxyy, xcyc
-    - [ ] Dataloader
-    - [ ] Data augment
+    - [x] Dataloader
+    - [x] Data augment
 - [ ] Model
     - [ ] load model
         - [ ] from yaml
         - [ ] from github
-    - [ ] trainer
-        - [ ] train_one_iter
-        - [ ] train_one_epoch
-    - [ ] DDP, EMA, OTA
+    - [x] trainer
+        - [x] train_one_iter
+        - [x] train_one_epoch
+    - [ ] DDP
+    - [x] EMA, OTA
+    - [ ] Loss
 - [ ] Run
     - [ ] train
     - [ ] test
     - [ ] demo
-- [ ] Configuration
-    - [ ] hyperparams: dataclass
-    - [ ] model cfg: yaml
+- [x] Configuration
examples/example_train.py ADDED
@@ -0,0 +1,35 @@
+import sys
+from pathlib import Path
+
+import hydra
+import torch
+from loguru import logger
+
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+from yolo.config.config import Config
+from yolo.model.yolo import get_model
+from yolo.tools.log_helper import custom_logger
+from yolo.tools.trainer import Trainer
+from yolo.utils.dataloader import get_dataloader
+from yolo.utils.get_dataset import prepare_dataset
+
+
+@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
+def main(cfg: Config):
+    if cfg.download.auto:
+        prepare_dataset(cfg.download)
+
+    dataloader = get_dataloader(cfg)
+    model = get_model(cfg.model)
+    # TODO: get_device or rank, for DDP mode
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    trainer = Trainer(model, cfg.hyper.train, device)
+    trainer.train(dataloader, 10)
+
+
+if __name__ == "__main__":
+    custom_logger()
+    main()
{config β†’ yolo/config}/README.md RENAMED
File without changes
yolo/config/config.py ADDED
@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from typing import Dict, List, Union
+
+
+@dataclass
+class Model:
+    anchor: List[List[int]]
+    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
+
+
+@dataclass
+class DataLoaderConfig:
+    batch_size: int
+    shuffle: bool
+    num_workers: int
+    pin_memory: bool
+
+
+@dataclass
+class OptimizerArgs:
+    lr: float
+    weight_decay: float
+
+
+@dataclass
+class OptimizerConfig:
+    type: str
+    args: OptimizerArgs
+
+
+@dataclass
+class SchedulerArgs:
+    step_size: int
+    gamma: float
+
+
+@dataclass
+class SchedulerConfig:
+    type: str
+    args: SchedulerArgs
+
+
+@dataclass
+class EMAConfig:
+    enabled: bool
+    decay: float
+
+
+@dataclass
+class TrainConfig:
+    optimizer: OptimizerConfig
+    scheduler: SchedulerConfig
+    ema: EMAConfig
+
+
+@dataclass
+class HyperConfig:
+    data: DataLoaderConfig
+    train: TrainConfig
+
+
+@dataclass
+class Dataset:
+    file_name: str
+    file_num: int
+
+
+@dataclass
+class Datasets:
+    base_url: str
+    images: Dict[str, Dataset]
+
+
+@dataclass
+class Download:
+    auto: bool
+    save_path: str
+    datasets: Datasets
+
+
+@dataclass
+class Config:
+    model: Model
+    download: Download
+    hyper: HyperConfig
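
These dataclasses mirror the YAML tree that Hydra composes at startup. A minimal sketch of the pairing, using OmegaConf's structured configs to type-check one fragment (the explicit merge below is illustrative only; the commit itself just type-hints `cfg: Config` in `@hydra.main`):

```python
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class OptimizerArgs:
    lr: float
    weight_decay: float


@dataclass
class OptimizerConfig:
    type: str
    args: OptimizerArgs


# Same shape as the optimizer block in hyper/default.yaml
yaml_cfg = OmegaConf.create({"type": "Adam", "args": {"lr": 0.001, "weight_decay": 0.0001}})

# Merging against the structured schema validates field names and value types
cfg = OmegaConf.merge(OmegaConf.structured(OptimizerConfig), yaml_cfg)
assert cfg.args.lr == 0.001
```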
yolo/config/config.yaml ADDED
@@ -0,0 +1,11 @@
+hydra:
+  run:
+    dir: ./runs
+
+defaults:
+  - data: coco
+  - download: ../data/download
+  - augmentation: ../data/augmentation
+  - model: v7-base
+  - hyper: default
+  - _self_
yolo/config/data/augmentation.yaml ADDED
@@ -0,0 +1,3 @@
+Mosaic: 1
+# MixUp: 1
+HorizontalFlip: 0.5
{config β†’ yolo/config}/data/coco.yaml RENAMED
File without changes
yolo/config/data/download.yaml ADDED
@@ -0,0 +1,21 @@
+auto: True
+save_path: data/coco
+datasets:
+  images:
+    base_url: http://images.cocodataset.org/zips/
+    train2017:
+      file_name: train2017
+      file_num: 118287
+    val2017:
+      file_name: val2017
+      file_num: 5000
+    test2017:
+      file_name: test2017
+      file_num: 40670
+  annotations:
+    base_url: http://images.cocodataset.org/annotations/
+    annotations:
+      file_name: annotations_trainval2017
+hydra:
+  run:
+    dir: ./runs
yolo/config/hyper/default.yaml ADDED
@@ -0,0 +1,19 @@
+data:
+  batch_size: 4
+  shuffle: True
+  num_workers: 4
+  pin_memory: True
+train:
+  optimizer:
+    type: Adam
+    args:
+      lr: 0.001
+      weight_decay: 0.0001
+  scheduler:
+    type: StepLR
+    args:
+      step_size: 10
+      gamma: 0.1
+  ema:
+    enabled: true
+    decay: 0.995
{config β†’ yolo/config}/model/v7-base.yaml RENAMED
File without changes
{model β†’ yolo/model}/README.md RENAMED
File without changes
{model β†’ yolo/model}/module.py RENAMED
File without changes
{model β†’ yolo/model}/yolo.py RENAMED
@@ -5,7 +5,7 @@ import torch.nn as nn
 from loguru import logger
 from omegaconf import OmegaConf
 
-from tools.layer_helper import get_layer_map
+from yolo.tools.layer_helper import get_layer_map
 
 
 class YOLO(nn.Module):
yolo/tools/__init__.py ADDED
File without changes
yolo/tools/dataset_helper.py ADDED
@@ -0,0 +1,103 @@
+import json
+import os
+from itertools import chain
+from os import path
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+
+
+def find_labels_path(dataset_path: str, phase_name: str):
+    """
+    Find the path to label files for a specified dataset and phase (e.g., training).
+
+    Args:
+        dataset_path (str): The path to the root directory of the dataset.
+        phase_name (str): The name of the phase for which labels are being searched (e.g., "train", "val", "test").
+
+    Returns:
+        Tuple[str, str]: A tuple containing the path to the labels file and the file format ("json" or "txt").
+    """
+    json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
+
+    txt_labels_path = path.join(dataset_path, "label", phase_name)
+
+    if path.isfile(json_labels_path):
+        return json_labels_path, "json"
+
+    elif path.isdir(txt_labels_path):
+        txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")]
+        if txt_files:
+            return txt_labels_path, "txt"
+
+    raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
+
+
+def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
+    """
+    Create a dictionary containing image information and annotations indexed by image ID.
+
+    Args:
+        labels_path (str): The path to the annotation JSON file.
+
+    Returns:
+        - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
+        - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
+    """
+    with open(labels_path, "r") as file:
+        labels_data = json.load(file)
+    annotations_index = index_annotations_by_image(labels_data)  # TODO: is "lookup" a good name?
+    image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
+    return annotations_index, image_info_dict
+
+
+def index_annotations_by_image(data: Dict[str, Any]):
+    """
+    Build a lookup from image ID to all of that image's annotations.
+
+    Args:
+        data (Dict[str, Any]): A dictionary containing annotation data.
+
+    Returns:
+        Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
+        Annotations with "iscrowd" set to True are excluded from the index.
+    """
+    annotation_lookup = {}
+    for anno in data["annotations"]:
+        if anno["iscrowd"]:
+            continue
+        image_id = anno["image_id"]
+        if image_id not in annotation_lookup:
+            annotation_lookup[image_id] = []
+        annotation_lookup[image_id].append(anno)
+    return annotation_lookup
+
+
+def get_scaled_segmentation(
+    annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
+) -> Optional[List[List[float]]]:
+    """
+    Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.
+
+    Args:
+        annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
+        image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).
+
+    Returns:
+        Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.
+    """
+    if annotations is None:
+        return None
+
+    seg_array_with_cat = []
+    h, w = image_dimensions["height"], image_dimensions["width"]
+    for anno in annotations:
+        category_id = anno["category_id"]
+        seg_list = [item for sublist in anno["segmentation"] for item in sublist]
+        # Group the flat list into (x, y) pairs and scale by image width and height
+        scaled_seg_data = (np.array(seg_list).reshape(-1, 2) / [w, h]).tolist()
+        # Prepend the category id, then flatten the scaled pairs back out
+        scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data))
+        seg_array_with_cat.append(scaled_flat_seg_data)
+
+    return seg_array_with_cat
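
To make `get_scaled_segmentation`'s math concrete, here is a tiny self-contained run of its core transform (the numbers are made up for illustration):

```python
import numpy as np

# One polygon annotation on a 640x480 image
anno = {"category_id": 3, "segmentation": [[320.0, 240.0, 64.0, 48.0]]}
h, w = 480, 640

# Flatten, group into (x, y) pairs, scale by (width, height)
seg_list = [v for poly in anno["segmentation"] for v in poly]
scaled = (np.array(seg_list).reshape(-1, 2) / [w, h]).tolist()

# Prepend the category id, then flatten back out
flat = [anno["category_id"]] + [v for xy in scaled for v in xy]
print(flat)  # [3, 0.5, 0.5, 0.1, 0.1]
```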
{tools β†’ yolo/tools}/layer_helper.py RENAMED
@@ -2,7 +2,7 @@ import inspect
 
 import torch.nn as nn
 
-from model import module
+from yolo.model import module
 
 
 def auto_pad():
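
`get_layer_map` itself is outside this hunk. A plausible sketch, assuming it uses the `inspect` import above to collect the `nn.Module` subclasses defined in `yolo/model/module.py` (a hypothetical reconstruction, not the committed body):

```python
import inspect

import torch.nn as nn


def get_layer_map(module) -> dict:
    # Map class name -> class for every nn.Module subclass defined in `module`,
    # so model YAML entries can be resolved to layer constructors by name.
    return {
        name: cls
        for name, cls in inspect.getmembers(module, inspect.isclass)
        if issubclass(cls, nn.Module)
    }
```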
{tools β†’ yolo/tools}/log_helper.py RENAMED
File without changes
{tools β†’ yolo/tools}/model_helper.py RENAMED
@@ -4,7 +4,7 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
-from config.config import OptimizerConfig, SchedulerConfig
+from yolo.config.config import OptimizerConfig, SchedulerConfig
 
 
 class EMA:
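
The bodies of `get_optimizer` and `get_scheduler` are also outside this hunk. Given the `type`/`args` split in `hyper/default.yaml`, a reasonable sketch (an assumption, not the committed implementation) resolves the name against `torch.optim`:

```python
import torch
from torch import nn
from torch.optim import Optimizer


def get_optimizer(model: nn.Module, cfg) -> Optimizer:
    # Assumes cfg.type names a torch.optim class ("Adam", "SGD", ...)
    # and cfg.args is a dict-like mapping of its keyword arguments.
    optimizer_class = getattr(torch.optim, cfg.type)
    return optimizer_class(model.parameters(), **cfg.args)
```

With `hyper/default.yaml` this would yield `torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)`.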
{tools β†’ yolo/tools}/trainer.py RENAMED
@@ -2,10 +2,10 @@ import torch
 from loguru import logger
 from tqdm import tqdm
 
-from config.config import TrainConfig
-from model.yolo import YOLO
-from tools.model_helper import EMA, get_optimizer, get_scheduler
-from utils.loss import get_loss_function
+from yolo.config.config import TrainConfig
+from yolo.model.yolo import YOLO
+from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
+from yolo.utils.loss import get_loss_function
 
 
 class Trainer:
{utils β†’ yolo/utils}/README.md RENAMED
File without changes
yolo/utils/converter_json2txt.py ADDED
@@ -0,0 +1,91 @@
+import json
+import os
+from typing import Dict, List, Optional
+
+from tqdm import tqdm
+
+
+def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
+    """
+    Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
+    Indices are assigned based on the sorted 'id' values.
+    """
+    sorted_categories = sorted(categories, key=lambda category: category["id"])
+    return {category["id"]: index for index, category in enumerate(sorted_categories)}
+
+
+def process_annotations(
+    image_annotations: Dict[int, List[Dict]],
+    image_info_dict: Dict[int, tuple],
+    output_dir: str,
+    id_to_idx: Optional[Dict[int, int]] = None,
+) -> None:
+    """
+    Process and save annotations to files, with the option to remap category IDs.
+    """
+    for image_id, annotations in tqdm(image_annotations.items(), desc="Processing annotations"):
+        file_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
+        if not annotations:
+            continue
+        with open(file_path, "w") as file:
+            for annotation in annotations:
+                process_annotation(annotation, image_info_dict[image_id], id_to_idx, file)
+
+
+def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
+    """
+    Convert a single annotation's segmentation and write it to the open file handle.
+    """
+    category_id = annotation["category_id"]
+    segmentation = (
+        annotation["segmentation"][0]
+        if annotation["segmentation"] and isinstance(annotation["segmentation"][0], list)
+        else None
+    )
+
+    if segmentation is None:
+        return
+
+    img_width, img_height = image_dims
+    normalized_segmentation = normalize_segmentation(segmentation, img_width, img_height)
+
+    if id_to_idx:
+        category_id = id_to_idx.get(category_id, category_id)
+
+    file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
+
+
+def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
+    """
+    Normalize and format segmentation coordinates.
+    """
+    normalized = [
+        f"{coord / img_width:.6f}" if index % 2 == 0 else f"{coord / img_height:.6f}"
+        for index, coord in enumerate(segmentation)
+    ]
+    return normalized
+
+
+def convert_annotations(json_file: str, output_dir: str) -> None:
+    """
+    Load annotation data from a JSON file and process all annotations.
+    """
+    with open(json_file) as file:
+        data = json.load(file)
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    image_info_dict = {img["id"]: (img["width"], img["height"]) for img in data.get("images", [])}
+    id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None
+    image_annotations = {img_id: [] for img_id in image_info_dict}
+
+    for annotation in data.get("annotations", []):
+        if not annotation.get("iscrowd", False):
+            image_annotations[annotation["image_id"]].append(annotation)
+
+    process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
+
+
+if __name__ == "__main__":
+    convert_annotations("./data/coco/annotations/instances_train2017.json", "./data/coco/labels/train2017/")
+    convert_annotations("./data/coco/annotations/instances_val2017.json", "./data/coco/labels/val2017/")
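
`discretize_categories` matters because COCO category ids are sparse (1 to 90 with gaps) while training usually wants dense indices; a tiny worked example of the mapping:

```python
# Illustrative subset of COCO categories; the ids are sparse
categories = [{"id": 90}, {"id": 1}, {"id": 3}]

sorted_categories = sorted(categories, key=lambda category: category["id"])
id_to_idx = {category["id"]: index for index, category in enumerate(sorted_categories)}
print(id_to_idx)  # {1: 0, 3: 1, 90: 2}
```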
yolo/utils/data_augment.py ADDED
@@ -0,0 +1,125 @@
+import numpy as np
+import torch
+from PIL import Image
+from torchvision.transforms import functional as TF
+
+
+class Compose:
+    """Composes several transforms together."""
+
+    def __init__(self, transforms, image_size: int = 640):
+        self.transforms = transforms
+        self.image_size = image_size
+
+        for transform in self.transforms:
+            if hasattr(transform, "set_parent"):
+                transform.set_parent(self)
+
+    def __call__(self, image, boxes):
+        for transform in self.transforms:
+            image, boxes = transform(image, boxes)
+        return image, boxes
+
+
+class HorizontalFlip:
+    """Randomly horizontally flips the image along with the bounding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) < self.prob:
+            image = TF.hflip(image)
+            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
+        return image, boxes
+
+
+class VerticalFlip:
+    """Randomly vertically flips the image along with the bounding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) < self.prob:
+            image = TF.vflip(image)
+            boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
+        return image, boxes
+
+
+class Mosaic:
+    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+        self.parent = None
+
+    def set_parent(self, parent):
+        self.parent = parent
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) >= self.prob:
+            return image, boxes
+
+        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
+
+        img_sz = self.parent.image_size  # image_size is defined on the parent Compose
+        more_data = self.parent.get_more_data(3)  # get 3 more images randomly
+
+        data = [(image, boxes)] + more_data
+        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
+        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
+        center = np.array([img_sz, img_sz])
+        all_labels = []
+
+        for (image, boxes), vector in zip(data, vectors):
+            this_w, this_h = image.size
+            coord = tuple(center + vector * np.array([this_w, this_h]))
+
+            mosaic_image.paste(image, coord)
+            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
+            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
+            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
+            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
+            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
+
+            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
+            all_labels.append(adjusted_boxes)
+
+        all_labels = torch.cat(all_labels, dim=0)
+        mosaic_image = mosaic_image.resize((img_sz, img_sz))
+        return mosaic_image, all_labels
+
+
+class MixUp:
+    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
+
+    def __init__(self, prob=0.5, alpha=1.0):
+        self.alpha = alpha
+        self.prob = prob
+        self.parent = None
+
+    def set_parent(self, parent):
+        """Set the parent dataset object for accessing dataset methods."""
+        self.parent = parent
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) >= self.prob:
+            return image, boxes
+
+        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
+
+        # Retrieve another image and its boxes randomly from the dataset
+        image2, boxes2 = self.parent.get_more_data()[0]
+
+        # Calculate the mixup lambda parameter
+        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
+
+        # Mix images
+        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
+        mixed_image = lam * image1 + (1 - lam) * image2
+
+        # Mix bounding boxes
+        mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
+
+        return TF.to_pil_image(mixed_image), mixed_boxes
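
Since `Mosaic` and `MixUp` pull extra samples through their parent `Compose`, the wiring matters; a minimal usage sketch in which the `get_more_data` stub stands in for `YoloDataset.get_more_data`:

```python
import torch
from PIL import Image

from yolo.utils.data_augment import Compose, HorizontalFlip, Mosaic

transform = Compose([HorizontalFlip(prob=0.5), Mosaic(prob=1.0)], image_size=640)

# Stand-in for YoloDataset.get_more_data: returns (image, boxes) pairs
transform.get_more_data = lambda num=1: [
    (Image.new("RGB", (640, 640)), torch.tensor([[0.0, 0.1, 0.1, 0.5, 0.5]]))
    for _ in range(num)
]

image = Image.new("RGB", (640, 640))
boxes = torch.tensor([[1.0, 0.2, 0.2, 0.8, 0.8]])  # [class, xmin, ymin, xmax, ymax]
image, boxes = transform(image, boxes)  # mosaic of 4 tiles, boxes re-normalized
```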
yolo/utils/dataloader.py ADDED
@@ -0,0 +1,207 @@
+import os
+from os import path
+from typing import List, Tuple, Union
+
+import diskcache as dc
+import hydra
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import functional as TF
+from tqdm.rich import tqdm
+
+from yolo.tools.dataset_helper import (
+    create_image_info_dict,
+    find_labels_path,
+    get_scaled_segmentation,
+)
+from yolo.utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
+from yolo.utils.drawer import draw_bboxes
+
+
+class YoloDataset(Dataset):
+    def __init__(self, config: dict, phase: str = "train2017", image_size: int = 640):
+        dataset_cfg = config.data
+        augment_cfg = config.augmentation
+        phase_name = dataset_cfg.get(phase, phase)
+        self.image_size = image_size
+
+        # Each augmentation name in the config maps to a class in data_augment
+        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
+        self.transform = Compose(transforms, self.image_size)
+        self.transform.get_more_data = self.get_more_data
+        self.data = self.load_data(dataset_cfg.path, phase_name)
+
+    def load_data(self, dataset_path, phase_name):
+        """
+        Loads data from a cache or generates a new cache for a specific dataset phase.
+
+        Parameters:
+            dataset_path (str): The root path to the dataset directory.
+            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
+
+        Returns:
+            dict: The loaded data from the cache for the specified phase.
+        """
+        cache_path = path.join(dataset_path, ".cache")
+        cache = dc.Cache(cache_path)
+        data = cache.get(phase_name)
+
+        if data is None:
+            logger.info("Generating {} cache", phase_name)
+            data = self.filter_data(dataset_path, phase_name)
+            cache[phase_name] = data
+
+        cache.close()
+        logger.info("📦 Loaded {} cache", phase_name)
+        return data
+
+    def filter_data(self, dataset_path: str, phase_name: str) -> list:
+        """
+        Filters and collects dataset information by pairing images with their corresponding labels.
+
+        Parameters:
+            dataset_path (str): Path to the root directory of the dataset.
+            phase_name (str): The dataset phase whose images and labels should be paired.
+
+        Returns:
+            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
+        """
+        images_path = path.join(dataset_path, "images", phase_name)
+        labels_path, data_type = find_labels_path(dataset_path, phase_name)
+        images_list = sorted(os.listdir(images_path))
+        if data_type == "json":
+            annotations_index, image_info_dict = create_image_info_dict(labels_path)
+
+        data = []
+        valid_inputs = 0
+        for image_name in tqdm(images_list, desc="Filtering data"):
+            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
+                continue
+            image_id, _ = path.splitext(image_name)
+
+            if data_type == "json":
+                image_info = image_info_dict.get(image_id, None)
+                if image_info is None:
+                    continue
+                annotations = annotations_index.get(image_info["id"], [])
+                image_seg_annotations = get_scaled_segmentation(annotations, image_info)
+                if not image_seg_annotations:
+                    continue
+
+            elif data_type == "txt":
+                label_path = path.join(labels_path, f"{image_id}.txt")
+                if not path.isfile(label_path):
+                    continue
+                with open(label_path, "r") as file:
+                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
+
+            labels = self.load_valid_labels(image_id, image_seg_annotations)
+            if labels is not None:
+                img_path = path.join(images_path, image_name)
+                data.append((img_path, labels))
+                valid_inputs += 1
+
+        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
+        return data
+
+    def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
+        """
+        Loads and validates that bounding box data lies in [0, 1].
+
+        Parameters:
+            label_path (str): The identifier of the label source, used for warning messages.
+
+        Returns:
+            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
+        """
+        bboxes = []
+        for seg_data in seg_data_one_img:
+            cls = seg_data[0]
+            points = np.array(seg_data[1:]).reshape(-1, 2)
+            # Keep only (x, y) pairs where both coordinates fall inside [0, 1]
+            valid_points = points[((points >= 0) & (points <= 1)).all(axis=1)]
+            if valid_points.size > 1:
+                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
+                bboxes.append(bbox)
+
+        if bboxes:
+            return torch.stack(bboxes)
+        else:
+            logger.warning("No valid BBox in {}", label_path)
+            return None
+
+    def get_data(self, idx):
+        img_path, bboxes = self.data[idx]
+        img = Image.open(img_path).convert("RGB")
+        return img, bboxes
+
+    def get_more_data(self, num: int = 1):
+        indices = torch.randint(0, len(self), (num,))
+        return [self.get_data(idx) for idx in indices]
+
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
+        img, bboxes = self.get_data(idx)
+        if self.transform:
+            img, bboxes = self.transform(img, bboxes)
+        img = TF.to_tensor(img)
+        return img, bboxes
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class YoloDataLoader(DataLoader):
+    def __init__(self, config: dict):
+        """Initializes the YoloDataLoader with hydra-config files."""
+        hyper = config.hyper.data
+        dataset = YoloDataset(config)
+
+        super().__init__(
+            dataset,
+            batch_size=hyper.batch_size,
+            shuffle=hyper.shuffle,
+            num_workers=hyper.num_workers,
+            pin_memory=hyper.pin_memory,
+            collate_fn=self.collate_fn,
+        )
+
+    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        A collate function to handle batching of images and their corresponding targets.
+
+        Args:
+            batch (list of tuples): Each tuple contains:
+                - image (torch.Tensor): The image tensor.
+                - labels (torch.Tensor): The tensor of labels for the image.
+
+        Returns:
+            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
+                - A tensor of batched images.
+                - A list of tensors, each corresponding to bboxes for each image in the batch.
+        """
+        images = torch.stack([item[0] for item in batch])
+        targets = [item[1] for item in batch]
+        return images, targets
+
+
+def get_dataloader(config):
+    return YoloDataLoader(config)
+
+
+@hydra.main(config_path="../config", config_name="config", version_base=None)
+def main(cfg):
+    dataloader = get_dataloader(cfg)
+    draw_bboxes(*next(iter(dataloader)))
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
+    from yolo.tools.log_helper import custom_logger
+
+    custom_logger()
+    main()
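
Because box counts differ per image, `collate_fn` keeps the targets as a list rather than stacking them; a shape-level sketch with dummy tensors:

```python
import torch

# Two dummy samples: 3x640x640 images with 2 and 5 boxes respectively
batch = [
    (torch.zeros(3, 640, 640), torch.zeros(2, 5)),
    (torch.zeros(3, 640, 640), torch.zeros(5, 5)),
]

images = torch.stack([item[0] for item in batch])  # shape [2, 3, 640, 640]
targets = [item[1] for item in batch]              # list of [N_i, 5] tensors
print(images.shape, [t.shape for t in targets])
```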
yolo/utils/drawer.py ADDED
@@ -0,0 +1,41 @@
+from typing import List, Union
+
+import torch
+from loguru import logger
+from PIL import Image, ImageDraw, ImageFont
+from torchvision.transforms.functional import to_pil_image
+
+
+def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
+    """
+    Draw bounding boxes on an image.
+
+    Args:
+    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
+    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
+      where coordinates are normalized [0, 1].
+    """
+    # Convert tensor image to PIL Image if necessary
+    if isinstance(img, torch.Tensor):
+        if img.dim() > 3:
+            logger.info("Multi-frame tensor detected, using the first image.")
+            img = img[0]
+            bboxes = bboxes[0]
+        img = to_pil_image(img)
+
+    draw = ImageDraw.Draw(img)
+    width, height = img.size
+    font = ImageFont.load_default(30)
+
+    for bbox in bboxes:
+        class_id, x_min, y_min, x_max, y_max = bbox
+        x_min = x_min * width
+        x_max = x_max * width
+        y_min = y_min * height
+        y_max = y_max * height
+        shape = [(x_min, y_min), (x_max, y_max)]
+        draw.rectangle(shape, outline="red", width=3)
+        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
+
+    img.save("visualize.jpg")  # Save the image with annotations
+    logger.info("Saved visualized image at visualize.jpg")
yolo/utils/get_dataset.py ADDED
@@ -0,0 +1,86 @@
+import os
+import zipfile
+
+import requests
+from hydra import main
+from loguru import logger
+from tqdm import tqdm
+
+
+def download_file(url, destination):
+    """
+    Downloads a file from the specified URL to the destination path with progress logging.
+    """
+    logger.info(f"Downloading {os.path.basename(destination)}...")
+    with requests.get(url, stream=True) as response:
+        response.raise_for_status()
+        total_size = int(response.headers.get("content-length", 0))
+        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
+
+        with open(destination, "wb") as file:
+            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
+                file.write(data)
+                progress.update(len(data))
+        progress.close()
+    logger.info("Download completed.")
+
+
+def unzip_file(source, destination):
+    """
+    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
+    """
+    logger.info(f"Unzipping {os.path.basename(source)}...")
+    with zipfile.ZipFile(source, "r") as zip_ref:
+        zip_ref.extractall(destination)
+    os.remove(source)
+    logger.info(f"Removed {source}.")
+
+
+def check_files(directory, expected_count=None):
+    """
+    Returns True if the number of files in the directory matches expected_count, False otherwise.
+    """
+    if not os.path.isdir(directory):
+        return False
+    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+    return len(files) == expected_count if expected_count is not None else bool(files)
+
+
+@main(config_path="../config/data", config_name="download", version_base=None)
+def prepare_dataset(cfg):
+    """
+    Prepares the dataset by downloading and unzipping it if necessary.
+    """
+    data_dir = cfg.save_path
+    for data_type, settings in cfg.datasets.items():
+        base_url = settings["base_url"]
+        for dataset_type, dataset_args in settings.items():
+            if dataset_type == "base_url":
+                continue  # Skip the base_url entry
+            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
+            url = f"{base_url}{file_name}"
+            local_zip_path = os.path.join(data_dir, file_name)
+            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
+            final_place = os.path.join(extract_to, dataset_type)
+
+            os.makedirs(extract_to, exist_ok=True)
+            if check_files(final_place, dataset_args.get("file_num")):
+                logger.info(f"Dataset {dataset_type} already verified.")
+                continue
+
+            if not os.path.exists(local_zip_path):
+                download_file(url, local_zip_path)
+            unzip_file(local_zip_path, extract_to)
+
+            if not check_files(final_place, dataset_args.get("file_num")):
+                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
+    from yolo.tools.log_helper import custom_logger
+
+    custom_logger()
+    prepare_dataset()
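
Each dataset entry in `download.yaml` expands to one zip URL and one extraction target; tracing the same string logic for the train split (paths assume the default `save_path: data/coco`):

```python
import os

base_url = "http://images.cocodataset.org/zips/"
file_name = "train2017.zip"

url = f"{base_url}{file_name}"                         # .../zips/train2017.zip
local_zip_path = os.path.join("data/coco", file_name)  # data/coco/train2017.zip
extract_to = os.path.join("data/coco", "images")       # data/coco/images
print(url, local_zip_path, extract_to)
```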
yolo/utils/loss.py ADDED
@@ -0,0 +1,2 @@
+def get_loss_function(*args, **kwargs):
+    raise NotImplementedError