henry000 committed on
Commit cbbfcfe · 1 Parent(s): 4b68a08

🚚 [Move] loss function to yolo/utils

also removes duplicate files and the main() entry point of the loss module; pytest tests still need to be written
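Since the commit message flags missing tests, here is a minimal pytest sketch for the relocated module (test file name and layout are hypothetical; the only behavior it checks is the one this commit pins down, namely that get_loss_function is still an unimplemented stub):

# tests/test_loss.py -- hypothetical location; exercises the stub behavior
# guaranteed by this commit.
import pytest

from yolo.utils.loss import get_loss_function


def test_get_loss_function_is_still_a_stub():
    # get_loss_function raises until a real factory is implemented.
    with pytest.raises(NotImplementedError):
        get_loss_function()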

config/config.py DELETED
@@ -1,107 +0,0 @@
- from dataclasses import dataclass
- from typing import Dict, List, Union
-
-
- @dataclass
- class AnchorConfig:
-     reg_max: int
-     strides: List[int]
-
-
- @dataclass
- class Model:
-     anchor: AnchorConfig
-     model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
-
-
- @dataclass
- class Download:
-     auto: bool
-     path: str
-
-
- @dataclass
- class DataLoaderConfig:
-     batch_size: int
-     shuffle: bool
-     num_workers: int
-     pin_memory: bool
-     image_size: List[int]
-     class_num: int
-
-
- @dataclass
- class OptimizerArgs:
-     lr: float
-     weight_decay: float
-
-
- @dataclass
- class OptimizerConfig:
-     type: str
-     args: OptimizerArgs
-
-
- @dataclass
- class SchedulerArgs:
-     step_size: int
-     gamma: float
-
-
- @dataclass
- class SchedulerConfig:
-     type: str
-     args: SchedulerArgs
-
-
- @dataclass
- class EMAConfig:
-     enabled: bool
-     decay: float
-
-
- @dataclass
- class MatcherConfig:
-     iou: str
-     topk: int
-     factor: Dict[str, int]
-
-
- @dataclass
- class TrainConfig:
-     optimizer: OptimizerConfig
-     scheduler: SchedulerConfig
-     ema: EMAConfig
-     matcher: MatcherConfig
-
-
- @dataclass
- class HyperConfig:
-     data: DataLoaderConfig
-     train: TrainConfig
-
-
- @dataclass
- class Dataset:
-     file_name: str
-     num_files: int
-
-
- @dataclass
- class Datasets:
-     base_url: str
-     images: Dict[str, Dataset]
-
-
- @dataclass
- class Download:
-     auto: bool
-     save_path: str
-     datasets: Datasets
-
-
- @dataclass
- class Config:
-     model: Model
-     download: Download
-     hyper: HyperConfig

config/config.yaml DELETED
@@ -1,11 +0,0 @@
- hydra:
-   run:
-     dir: ./runs
-
- defaults:
-   - data: coco
-   - download: ../data/download
-   - augmentation: ../data/augmentation
-   - model: v7-base
-   - hyper: default
-   - _self_

config/data/augmentation.yaml DELETED
@@ -1,3 +0,0 @@
- Mosaic: 1
- # MixUp: 1
- HorizontalFlip: 0.5

config/data/download.yaml DELETED
@@ -1,21 +0,0 @@
- auto: True
- save_path: data/coco
- datasets:
-   images:
-     base_url: http://images.cocodataset.org/zips/
-     train2017:
-       file_name: train2017
-       file_num: 118287
-     val2017:
-       file_name: val2017
-       file_num: 5000
-     test2017:
-       file_name: test2017
-       file_num: 40670
-   annotations:
-     base_url: http://images.cocodataset.org/annotations/
-     annotations:
-       file_name: annotations_trainval2017
- hydra:
-   run:
-     dir: ./runs

config/hyper/default.yaml DELETED
@@ -1,35 +0,0 @@
- data:
-   batch_size: 4
-   shuffle: True
-   num_workers: 4
-   pin_memory: True
-   class_num: 80
-   image_size: [640, 640]
- train:
-   optimizer:
-     type: Adam
-     args:
-       lr: 0.001
-       weight_decay: 0.0001
-   loss:
-     BCELoss:
-       args:
-     BoxLoss:
-       args:
-         alpha: 0.1
-     DFLoss:
-       args:
-   matcher:
-     iou: CIoU
-     topk: 10
-     factor:
-       iou: 6.0
-       cls: 0.5
-   scheduler:
-     type: StepLR
-     args:
-       step_size: 10
-       gamma: 0.1
-   ema:
-     enabled: true
-     decay: 0.995

utils/converter_json2txt.py DELETED
@@ -1,86 +0,0 @@
- import json
- import os
- from typing import Dict, List, Optional
-
- from tqdm import tqdm
-
-
- def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
-     """
-     Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
-     Indices are assigned based on the sorted 'id' values.
-     """
-     sorted_categories = sorted(categories, key=lambda category: category["id"])
-     return {category["id"]: index for index, category in enumerate(sorted_categories)}
-
-
- def process_annotations(
-     image_annotations: Dict[int, List[Dict]],
-     image_info_dict: Dict[int, tuple],
-     output_dir: str,
-     id_to_idx: Optional[Dict[int, int]] = None,
- ) -> None:
-     """
-     Process and save annotations to files, with option to remap category IDs.
-     """
-     for image_id, annotations in tqdm(image_annotations.items(), desc="Processing annotations"):
-         file_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
-         if not annotations:
-             continue
-         with open(file_path, "w") as file:
-             for annotation in annotations:
-                 process_annotation(annotation, image_info_dict[image_id], id_to_idx, file)
-
-
- def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
-     """
-     Convert a single annotation's segmentation and write it to the open file handle.
-     """
-     category_id = annotation["category_id"]
-     segmentation = (
-         annotation["segmentation"][0]
-         if annotation["segmentation"] and isinstance(annotation["segmentation"][0], list)
-         else None
-     )
-
-     if segmentation is None:
-         return
-
-     img_width, img_height = image_dims
-     normalized_segmentation = normalize_segmentation(segmentation, img_width, img_height)
-
-     if id_to_idx:
-         category_id = id_to_idx.get(category_id, category_id)
-
-     file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
-
-
- def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
-     """
-     Normalize and format segmentation coordinates.
-     """
-     return [f"{x/img_width:.6f}" if i % 2 == 0 else f"{x/img_height:.6f}" for i, x in enumerate(segmentation)]
-
-
- def convert_annotations(json_file: str, output_dir: str) -> None:
-     """
-     Load annotation data from a JSON file and process all annotations.
-     """
-     with open(json_file) as file:
-         data = json.load(file)
-
-     os.makedirs(output_dir, exist_ok=True)
-
-     image_info_dict = {img["id"]: (img["width"], img["height"]) for img in data.get("images", [])}
-     id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None
-     image_annotations = {img_id: [] for img_id in image_info_dict}
-
-     for annotation in data.get("annotations", []):
-         if not annotation.get("iscrowd", False):
-             image_annotations[annotation["image_id"]].append(annotation)
-
-     process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
-
-
- convert_annotations("./data/coco/annotations/instances_train2017.json", "./data/coco/labels/train2017/")
- convert_annotations("./data/coco/annotations/instances_val2017.json", "./data/coco/labels/val2017/")

utils/data_augment.py DELETED
@@ -1,125 +0,0 @@
- import numpy as np
- import torch
- from PIL import Image
- from torchvision.transforms import functional as TF
-
-
- class Compose:
-     """Composes several transforms together."""
-
-     def __init__(self, transforms, image_size: int = 640):
-         self.transforms = transforms
-         self.image_size = image_size
-
-         for transform in self.transforms:
-             if hasattr(transform, "set_parent"):
-                 transform.set_parent(self)
-
-     def __call__(self, image, boxes):
-         for transform in self.transforms:
-             image, boxes = transform(image, boxes)
-         return image, boxes
-
-
- class HorizontalFlip:
-     """Randomly horizontally flips the image along with the bounding boxes."""
-
-     def __init__(self, prob=0.5):
-         self.prob = prob
-
-     def __call__(self, image, boxes):
-         if torch.rand(1) < self.prob:
-             image = TF.hflip(image)
-             boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
-         return image, boxes
-
-
- class VerticalFlip:
-     """Randomly vertically flips the image along with the bounding boxes."""
-
-     def __init__(self, prob=0.5):
-         self.prob = prob
-
-     def __call__(self, image, boxes):
-         if torch.rand(1) < self.prob:
-             image = TF.vflip(image)
-             boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
-         return image, boxes
-
-
- class Mosaic:
-     """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
-
-     def __init__(self, prob=0.5):
-         self.prob = prob
-         self.parent = None
-
-     def set_parent(self, parent):
-         self.parent = parent
-
-     def __call__(self, image, boxes):
-         if torch.rand(1) >= self.prob:
-             return image, boxes
-
-         assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
-
-         img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
-         more_data = self.parent.get_more_data(3)  # get 3 more images randomly
-
-         data = [(image, boxes)] + more_data
-         mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
-         vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
-         center = np.array([img_sz, img_sz])
-         all_labels = []
-
-         for (image, boxes), vector in zip(data, vectors):
-             this_w, this_h = image.size
-             coord = tuple(center + vector * np.array([this_w, this_h]))
-
-             mosaic_image.paste(image, coord)
-             xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
-             xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
-             xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
-             ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
-             ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
-
-             adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
-             all_labels.append(adjusted_boxes)
-
-         all_labels = torch.cat(all_labels, dim=0)
-         mosaic_image = mosaic_image.resize((img_sz, img_sz))
-         return mosaic_image, all_labels
-
-
- class MixUp:
-     """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
-
-     def __init__(self, prob=0.5, alpha=1.0):
-         self.alpha = alpha
-         self.prob = prob
-         self.parent = None
-
-     def set_parent(self, parent):
-         """Set the parent dataset object for accessing dataset methods."""
-         self.parent = parent
-
-     def __call__(self, image, boxes):
-         if torch.rand(1) >= self.prob:
-             return image, boxes
-
-         assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
-
-         # Retrieve another image and its boxes randomly from the dataset
-         image2, boxes2 = self.parent.get_more_data()[0]
-
-         # Calculate the mixup lambda parameter
-         lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
-
-         # Mix images
-         image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
-         mixed_image = lam * image1 + (1 - lam) * image2
-
-         # Mix bounding boxes
-         mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
-
-         return TF.to_pil_image(mixed_image), mixed_boxes

utils/dataloader.py DELETED
@@ -1,206 +0,0 @@
- import os
- from os import path
- from typing import List, Tuple, Union
-
- import diskcache as dc
- import hydra
- import numpy as np
- import torch
- from loguru import logger
- from PIL import Image
- from torch.utils.data import DataLoader, Dataset
- from torchvision.transforms import functional as TF
- from tqdm.rich import tqdm
-
- from tools.dataset_helper import (
-     create_image_info_dict,
-     find_labels_path,
-     get_scaled_segmentation,
- )
- from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
- from utils.drawer import draw_bboxes
-
-
- class YoloDataset(Dataset):
-     def __init__(self, config: dict, phase: str = "train2017", image_size: int = 640):
-         dataset_cfg = config.data
-         augment_cfg = config.augmentation
-         phase_name = dataset_cfg.get(phase, phase)
-         self.image_size = image_size
-
-         transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
-         self.transform = Compose(transforms, self.image_size)
-         self.transform.get_more_data = self.get_more_data
-         self.data = self.load_data(dataset_cfg.path, phase_name)
-
-     def load_data(self, dataset_path, phase_name):
-         """
-         Loads data from a cache or generates a new cache for a specific dataset phase.
-
-         Parameters:
-             dataset_path (str): The root path to the dataset directory.
-             phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
-
-         Returns:
-             dict: The loaded data from the cache for the specified phase.
-         """
-         cache_path = path.join(dataset_path, ".cache")
-         cache = dc.Cache(cache_path)
-         data = cache.get(phase_name)
-
-         if data is None:
-             logger.info("Generating {} cache", phase_name)
-             data = self.filter_data(dataset_path, phase_name)
-             cache[phase_name] = data
-
-         cache.close()
-         logger.info("📦 Loaded {} cache", phase_name)
-         data = cache[phase_name]
-         return data
-
-     def filter_data(self, dataset_path: str, phase_name: str) -> list:
-         """
-         Filters and collects dataset information by pairing images with their corresponding labels.
-
-         Parameters:
-             images_path (str): Path to the directory containing image files.
-             labels_path (str): Path to the directory containing label files.
-
-         Returns:
-             list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
-         """
-         images_path = path.join(dataset_path, "images", phase_name)
-         labels_path, data_type = find_labels_path(dataset_path, phase_name)
-         images_list = sorted(os.listdir(images_path))
-         if data_type == "json":
-             annotations_index, image_info_dict = create_image_info_dict(labels_path)
-
-         data = []
-         valid_inputs = 0
-         for image_name in tqdm(images_list, desc="Filtering data"):
-             if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
-                 continue
-             image_id, _ = path.splitext(image_name)
-
-             if data_type == "json":
-                 image_info = image_info_dict.get(image_id, None)
-                 if image_info is None:
-                     continue
-                 annotations = annotations_index.get(image_info["id"], [])
-                 image_seg_annotations = get_scaled_segmentation(annotations, image_info)
-                 if not image_seg_annotations:
-                     continue
-
-             elif data_type == "txt":
-                 label_path = path.join(labels_path, f"{image_id}.txt")
-                 if not path.isfile(label_path):
-                     continue
-                 with open(label_path, "r") as file:
-                     image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
-
-             labels = self.load_valid_labels(image_id, image_seg_annotations)
-             if labels is not None:
-                 img_path = path.join(images_path, image_name)
-                 data.append((img_path, labels))
-                 valid_inputs += 1
-
-         logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
-         return data
-
-     def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
-         """
-         Loads and validates bounding box data is [0, 1] from a label file.
-
-         Parameters:
-             label_path (str): The filepath to the label file containing bounding box data.
-
-         Returns:
-             torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
-         """
-         bboxes = []
-         for seg_data in seg_data_one_img:
-             cls = seg_data[0]
-             points = np.array(seg_data[1:]).reshape(-1, 2)
-             valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
-             if valid_points.size > 1:
-                 bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
-                 bboxes.append(bbox)
-
-         if bboxes:
-             return torch.stack(bboxes)
-         else:
-             logger.warning("No valid BBox in {}", label_path)
-             return None
-
-     def get_data(self, idx):
-         img_path, bboxes = self.data[idx]
-         img = Image.open(img_path).convert("RGB")
-         return img, bboxes
-
-     def get_more_data(self, num: int = 1):
-         indices = torch.randint(0, len(self), (num,))
-         return [self.get_data(idx) for idx in indices]
-
-     def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
-         img, bboxes = self.get_data(idx)
-         if self.transform:
-             img, bboxes = self.transform(img, bboxes)
-         img = TF.to_tensor(img)
-         return img, bboxes
-
-     def __len__(self) -> int:
-         return len(self.data)
-
-
- class YoloDataLoader(DataLoader):
-     def __init__(self, config: dict):
-         """Initializes the YoloDataLoader with hydra-config files."""
-         hyper = config.hyper.data
-         dataset = YoloDataset(config)
-
-         super().__init__(
-             dataset,
-             batch_size=hyper.batch_size,
-             shuffle=hyper.shuffle,
-             num_workers=hyper.num_workers,
-             pin_memory=hyper.pin_memory,
-             collate_fn=self.collate_fn,
-         )
-
-     def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-         """
-         A collate function to handle batching of images and their corresponding targets.
-
-         Args:
-             batch (list of tuples): Each tuple contains:
-                 - image (torch.Tensor): The image tensor.
-                 - labels (torch.Tensor): The tensor of labels for the image.
-
-         Returns:
-             Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
-                 - A tensor of batched images.
-                 - A list of tensors, each corresponding to bboxes for each image in the batch.
-         """
-         images = torch.stack([item[0] for item in batch])
-         targets = [item[1] for item in batch]
-         return images, targets
-
-
- def get_dataloader(config):
-     return YoloDataLoader(config)
-
-
- @hydra.main(config_path="../config", config_name="config", version_base=None)
- def main(cfg):
-     dataloader = get_dataloader(cfg)
-     draw_bboxes(*next(iter(dataloader)))
-
-
- if __name__ == "__main__":
-     import sys
-
-     sys.path.append("./")
-     from tools.log_helper import custom_logger
-
-     custom_logger()
-     main()

utils/drawer.py DELETED
@@ -1,41 +0,0 @@
- from typing import List, Union
-
- import torch
- from loguru import logger
- from PIL import Image, ImageDraw, ImageFont
- from torchvision.transforms.functional import to_pil_image
-
-
- def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
-     """
-     Draw bounding boxes on an image.
-
-     Args:
-     - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
-     - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
-       where coordinates are normalized [0, 1].
-     """
-     # Convert tensor image to PIL Image if necessary
-     if isinstance(img, torch.Tensor):
-         if img.dim() > 3:
-             logger.info("Multi-frame tensor detected, using the first image.")
-             img = img[0]
-             bboxes = bboxes[0]
-         img = to_pil_image(img)
-
-     draw = ImageDraw.Draw(img)
-     width, height = img.size
-     font = ImageFont.load_default(30)
-
-     for bbox in bboxes:
-         class_id, x_min, y_min, x_max, y_max = bbox
-         x_min = x_min * width
-         x_max = x_max * width
-         y_min = y_min * height
-         y_max = y_max * height
-         shape = [(x_min, y_min), (x_max, y_max)]
-         draw.rectangle(shape, outline="red", width=3)
-         draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
-
-     img.save("visualize.jpg")  # Save the image with annotations
-     logger.info("Saved visualize image at visualize.png")

utils/get_dataset.py DELETED
@@ -1,84 +0,0 @@
- import os
- import zipfile
-
- import requests
- from hydra import main
- from loguru import logger
- from tqdm import tqdm
-
-
- def download_file(url, destination):
-     """
-     Downloads a file from the specified URL to the destination path with progress logging.
-     """
-     logger.info(f"Downloading {os.path.basename(destination)}...")
-     with requests.get(url, stream=True) as response:
-         response.raise_for_status()
-         total_size = int(response.headers.get("content-length", 0))
-         progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
-
-         with open(destination, "wb") as file:
-             for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
-                 file.write(data)
-                 progress.update(len(data))
-         progress.close()
-     logger.info("Download completed.")
-
-
- def unzip_file(source, destination):
-     """
-     Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
-     """
-     logger.info(f"Unzipping {os.path.basename(source)}...")
-     with zipfile.ZipFile(source, "r") as zip_ref:
-         zip_ref.extractall(destination)
-     os.remove(source)
-     logger.info(f"Removed {source}.")
-
-
- def check_files(directory, expected_count=None):
-     """
-     Returns True if the number of files in the directory matches expected_count, False otherwise.
-     """
-     files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
-     return len(files) == expected_count if expected_count is not None else bool(files)
-
-
- @main(config_path="../config/data", config_name="download", version_base=None)
- def prepare_dataset(cfg):
-     """
-     Prepares dataset by downloading and unzipping if necessary.
-     """
-     data_dir = cfg.save_path
-     for data_type, settings in cfg.datasets.items():
-         base_url = settings["base_url"]
-         for dataset_type, dataset_args in settings.items():
-             if dataset_type == "base_url":
-                 continue  # Skip the base_url entry
-             file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
-             url = f"{base_url}{file_name}"
-             local_zip_path = os.path.join(data_dir, file_name)
-             extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
-             final_place = os.path.join(extract_to, dataset_type)
-
-             os.makedirs(extract_to, exist_ok=True)
-             if check_files(final_place, dataset_args.get("file_num")):
-                 logger.info(f"Dataset {dataset_type} already verified.")
-                 continue
-
-             if not os.path.exists(local_zip_path):
-                 download_file(url, local_zip_path)
-             unzip_file(local_zip_path, extract_to)
-
-             if not check_files(final_place, dataset_args.get("file_num")):
-                 logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
-
-
- if __name__ == "__main__":
-     import sys
-
-     sys.path.append("./")
-     from tools.log_helper import custom_logger
-
-     custom_logger()
-     prepare_dataset()

utils/loss.py DELETED
@@ -1,184 +0,0 @@
- import sys
- import time
- from typing import Any, List
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from einops import rearrange
- from hydra import main
- from loguru import logger
- from torch import Tensor, nn
- from torch.nn import BCEWithLogitsLoss
-
- sys.path.append("./")
- from config.config import Config
- from tools.bbox_helper import BoxMatcher, calculate_iou, make_anchor, transform_bbox
-
-
- def get_loss_function(*args, **kwargs):
-     raise NotImplementedError
-
-
- class BCELoss(nn.Module):
-     def __init__(self) -> None:
-         super().__init__()
-         self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=torch.device("cuda")), reduction="none")
-
-     def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
-         return self.bce(predicts_cls, targets_cls).sum() / cls_norm
-
-
- class BoxLoss(nn.Module):
-     def __init__(self) -> None:
-         super().__init__()
-
-     def forward(
-         self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
-     ) -> Any:
-         valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
-         picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
-         picked_targets = targets_bbox[valid_bbox].view(-1, 4)
-
-         iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
-         loss_iou = 1.0 - iou
-         loss_iou = (loss_iou * box_norm).sum() / cls_norm
-         return loss_iou
-
-
- class DFLoss(nn.Module):
-     def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
-         super().__init__()
-         self.anchors = anchors
-         self.scaler = scaler
-         self.reg_max = reg_max
-
-     def forward(
-         self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
-     ) -> Any:
-         valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
-         bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
-         anchors_norm = (self.anchors / self.scaler[:, None])[None]
-         targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
-         picked_targets = targets_dist[valid_bbox].view(-1)
-         picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
-
-         label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
-         weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
-
-         loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
-         loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
-         loss_dfl = loss_left * weight_left + loss_right * weight_right
-         loss_dfl = loss_dfl.view(-1, 4).mean(-1)
-         loss_dfl = (loss_dfl * box_norm).sum() / cls_norm
-         return loss_dfl
-
-
- class YOLOLoss:
-     def __init__(self, cfg: Config) -> None:
-         self.reg_max = cfg.model.anchor.reg_max
-         self.class_num = cfg.hyper.data.class_num
-         self.image_size = list(cfg.hyper.data.image_size)
-         self.strides = cfg.model.anchor.strides
-         device = torch.device("cuda")
-
-         self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
-         self.scale_up = torch.tensor(self.image_size * 2, device=device)
-
-         self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
-
-         self.cls = BCELoss()
-         self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
-         self.iou = BoxLoss()
-
-         self.matcher = BoxMatcher(cfg.hyper.train.matcher, self.class_num, self.anchors)
-
-     def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
-         """
-         args:
-             [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
-         return:
-             [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
-         """
-         preds = []
-         for pred in predicts:
-             preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w-> B x hw x AC
-         preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-
-         preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
-         preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
-
-         pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
-
-         lt, rb = pred_LTRB.chunk(2, dim=-1)
-         pred_minXY = self.anchors - lt
-         pred_maxXY = self.anchors + rb
-         predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
-
-         return predicts, preds_anc
-
-     def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
-         """
-         return List:
-         """
-         targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
-         bbox_num = targets[:, 0].int().bincount()
-         batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
-         for instance_idx, bbox_num in enumerate(bbox_num):
-             instance_targets = targets[targets[:, 0] == instance_idx]
-             batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
-         return batch_targets
-
-     def separate_anchor(self, anchors):
-         """
-         separate anchor and bbouding box
-         """
-         anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
-         anchors_box = anchors_box / self.scaler[None, :, None]
-         return anchors_cls, anchors_box
-
-     @torch.autocast("cuda")
-     def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tensor:
-         # Batch_Size x (Anchor + Class) x H x W
-         tlist = [time.time()]
-         # TODO: check datatype, why targets has a little bit error with origin version
-         predicts, predicts_anc = self.parse_predicts(predicts[0])
-         targets = self.parse_targets(targets)
-
-         align_targets, valid_masks = self.matcher(targets, predicts)
-         # calculate loss between with instance and predict
-
-         targets_cls, targets_bbox = self.separate_anchor(align_targets)
-         predicts_cls, predicts_bbox = self.separate_anchor(predicts)
-
-         cls_norm = targets_cls.sum()
-         box_norm = targets_cls.sum(-1)[valid_masks]
-
-         ## -- CLS -- ##
-         loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
-         ## -- IOU -- ##
-         loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
-         ## -- DFL -- ##
-         loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
-
-         logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
-         tlist.append(time.time())
-         logger.info(f"Calculate Loss Run Time {np.diff(np.array(tlist)) * 1e3} ms")
-
-
- @main(config_path="../config", config_name="config", version_base=None)
- def main(cfg):
-     losser = YOLOLoss(cfg)
-     targets = torch.load("targets.pt")
-     predicts = torch.load("predicts.pt")
-     losser(predicts, targets)
-
-
- if __name__ == "__main__":
-     import sys
-
-     sys.path.append("./")
-     from tools.log_helper import custom_logger
-
-     custom_logger()
-     main()

yolo/config/config.py CHANGED
@@ -2,9 +2,15 @@ from dataclasses import dataclass
  from typing import Dict, List, Union


+ @dataclass
+ class AnchorConfig:
+     reg_max: int
+     strides: List[int]
+
+
  @dataclass
  class Model:
-     anchor: List[List[int]]
+     anchor: AnchorConfig
      model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]


@@ -20,6 +26,8 @@ class DataLoaderConfig:
      shuffle: bool
      num_workers: int
      pin_memory: bool
+     image_size: List[int]
+     class_num: int


  @dataclass
@@ -52,11 +60,19 @@ class EMAConfig:
      decay: float


+ @dataclass
+ class MatcherConfig:
+     iou: str
+     topk: int
+     factor: Dict[str, int]
+
+
  @dataclass
  class TrainConfig:
      optimizer: OptimizerConfig
      scheduler: SchedulerConfig
      ema: EMAConfig
+     matcher: MatcherConfig


  @dataclass

yolo/config/hyper/default.yaml CHANGED
@@ -3,12 +3,28 @@ data:
    shuffle: True
    num_workers: 4
    pin_memory: True
+   class_num: 80
+   image_size: [640, 640]
  train:
    optimizer:
      type: Adam
      args:
        lr: 0.001
        weight_decay: 0.0001
+   loss:
+     BCELoss:
+       args:
+     BoxLoss:
+       args:
+         alpha: 0.1
+     DFLoss:
+       args:
+   matcher:
+     iou: CIoU
+     topk: 10
+     factor:
+       iou: 6.0
+       cls: 0.5
    scheduler:
      type: StepLR
      args:

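The new matcher block above is what populates the MatcherConfig dataclass added to yolo/config/config.py. A quick hand check that the YAML and the dataclass agree (a sketch, not project code; note the dataclass annotates factor as Dict[str, int] while the YAML values 6.0 and 0.5 are floats, so the annotation may eventually want to be Dict[str, float]):

from yolo.config.config import MatcherConfig

# Mirror the matcher block from yolo/config/hyper/default.yaml by hand.
# Dataclasses do not enforce annotations at runtime, so the float factors
# are accepted even though the field is annotated Dict[str, int].
matcher = MatcherConfig(iou="CIoU", topk=10, factor={"iou": 6.0, "cls": 0.5})
assert matcher.topk == 10 and matcher.factor["cls"] == 0.5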
yolo/tools/bbox_helper.py CHANGED
@@ -5,7 +5,7 @@ import torch
  import torch.nn.functional as F
  from torch import Tensor

- from config.config import MatcherConfig
+ from yolo.config.config import MatcherConfig


  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:

yolo/utils/loss.py CHANGED
@@ -1,2 +1,166 @@
+ import time
+ from typing import Any, List, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from hydra import main
+ from loguru import logger
+ from torch import Tensor, nn
+ from torch.nn import BCEWithLogitsLoss
+
+ from yolo.config.config import Config
+ from yolo.tools.bbox_helper import (
+     BoxMatcher,
+     calculate_iou,
+     make_anchor,
+     transform_bbox,
+ )
+
+
  def get_loss_function(*args, **kwargs):
      raise NotImplementedError
+
+
+ class BCELoss(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=torch.device("cuda")), reduction="none")
+
+     def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
+         return self.bce(predicts_cls, targets_cls).sum() / cls_norm
+
+
+ class BoxLoss(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(
+         self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
+     ) -> Any:
+         valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
+         picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
+         picked_targets = targets_bbox[valid_bbox].view(-1, 4)
+
+         iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
+         loss_iou = 1.0 - iou
+         loss_iou = (loss_iou * box_norm).sum() / cls_norm
+         return loss_iou
+
+
+ class DFLoss(nn.Module):
+     def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
+         super().__init__()
+         self.anchors = anchors
+         self.scaler = scaler
+         self.reg_max = reg_max
+
+     def forward(
+         self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
+     ) -> Any:
+         valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
+         bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
+         anchors_norm = (self.anchors / self.scaler[:, None])[None]
+         targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
+         picked_targets = targets_dist[valid_bbox].view(-1)
+         picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
+
+         label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
+         weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
+
+         loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
+         loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
+         loss_dfl = loss_left * weight_left + loss_right * weight_right
+         loss_dfl = loss_dfl.view(-1, 4).mean(-1)
+         loss_dfl = (loss_dfl * box_norm).sum() / cls_norm
+         return loss_dfl
+
+
+ class YOLOLoss:
+     def __init__(self, cfg: Config) -> None:
+         self.reg_max = cfg.model.anchor.reg_max
+         self.class_num = cfg.hyper.data.class_num
+         self.image_size = list(cfg.hyper.data.image_size)
+         self.strides = cfg.model.anchor.strides
+         device = torch.device("cuda")
+
+         self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
+         self.scale_up = torch.tensor(self.image_size * 2, device=device)
+
+         self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
+
+         self.cls = BCELoss()
+         self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
+         self.iou = BoxLoss()
+
+         self.matcher = BoxMatcher(cfg.hyper.train.matcher, self.class_num, self.anchors)
+
+     def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
+         """
+         args:
+             [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
+         return:
+             [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
+         """
+         preds = []
+         for pred in predicts:
+             preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w-> B x hw x AC
+         preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
+
+         preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+         preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
+
+         pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
+
+         lt, rb = pred_LTRB.chunk(2, dim=-1)
+         pred_minXY = self.anchors - lt
+         pred_maxXY = self.anchors + rb
+         predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
+
+         return predicts, preds_anc
+
+     def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
+         """
+         return List:
+         """
+         targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
+         bbox_num = targets[:, 0].int().bincount()
+         batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
+         for instance_idx, bbox_num in enumerate(bbox_num):
+             instance_targets = targets[targets[:, 0] == instance_idx]
+             batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
+         return batch_targets
+
+     def separate_anchor(self, anchors):
+         """
+         separate anchor and bbouding box
+         """
+         anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
+         anchors_box = anchors_box / self.scaler[None, :, None]
+         return anchors_cls, anchors_box
+
+     @torch.autocast("cuda")
+     def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+         # Batch_Size x (Anchor + Class) x H x W
+         # TODO: check datatype, why targets has a little bit error with origin version
+         predicts, predicts_anc = self.parse_predicts(predicts[0])
+         targets = self.parse_targets(targets)
+
+         align_targets, valid_masks = self.matcher(targets, predicts)
+         # calculate loss between with instance and predict
+
+         targets_cls, targets_bbox = self.separate_anchor(align_targets)
+         predicts_cls, predicts_bbox = self.separate_anchor(predicts)
+
+         cls_norm = targets_cls.sum()
+         box_norm = targets_cls.sum(-1)[valid_masks]
+
+         ## -- CLS -- ##
+         loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
+         ## -- IOU -- ##
+         loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
+         ## -- DFL -- ##
+         loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
+
+         logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
+         return loss_iou, loss_dfl, loss_cls
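
With the main() entry point removed from the loss module, callers now have to compose the config and drive YOLOLoss themselves. A sketch reproducing what the deleted main() did (it assumes a CUDA device, that the snippet lives somewhere the relative config path below resolves from, and that the predicts.pt/targets.pt tensors the old entry point loaded are present):

import torch
from hydra import compose, initialize

from yolo.utils.loss import YOLOLoss

# Compose the same hydra config the removed @main decorator pointed at;
# hydra resolves config_path relative to the calling module, so adjust
# "yolo/config" to your caller's location.
with initialize(config_path="yolo/config", version_base=None):
    cfg = compose(config_name="config")

loss_fn = YOLOLoss(cfg)                # builds anchors and matcher on CUDA
predicts = torch.load("predicts.pt")   # tensors the deleted main() loaded
targets = torch.load("targets.pt")
loss_iou, loss_dfl, loss_cls = loss_fn(predicts, targets)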