import os
from copy import deepcopy
from math import exp
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from omegaconf import ListConfig
from torch import Tensor, no_grad
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler

from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
from yolo.model.yolo import YOLO
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_bbox
from yolo.utils.logger import logger


def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
    """
    Linearly interpolates between start and end values:
    start * (1 - step / total) + end * (step / total)

    Parameters:
        start (float): The starting value.
        end (float): The ending value.
        step (int | float): The current step in the interpolation process.
        total (int): The total number of steps.

    Returns:
        float: The interpolated value.
    """
    return start + (end - start) * step / total
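
# For example, lerp(0.0, 10.0, step=3, total=10) returns 3.0, i.e. three tenths of
# the way from 0.0 to 10.0. Because the arithmetic is element-wise, it also blends
# tensors, which is how the EMA callback below uses it.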


class EMA(Callback):
    def __init__(self, decay: float = 0.9999, tau: float = 2000):
        super().__init__()
        logger.info(":chart_with_upwards_trend: Enable Model EMA")
        self.decay = decay
        self.tau = tau
        self.step = 0
        self.ema_state_dict = None

    def setup(self, trainer, pl_module, stage):
        pl_module.ema = deepcopy(pl_module.model)
        self.tau /= trainer.world_size
        for param in pl_module.ema.parameters():
            param.requires_grad = False

    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
        if self.ema_state_dict is None:
            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
        pl_module.ema.load_state_dict(self.ema_state_dict)

    @no_grad()
    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        if self.ema_state_dict is None:
            # Guard for runs where no validation pass (e.g. the sanity check) ran before training
            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
        self.step += 1
        # Ramp the decay from 0 toward `self.decay` so early EMA weights track the live weights closely
        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
        for key, param in pl_module.model.state_dict().items():
            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)
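

# A minimal usage sketch, assuming a standard Lightning training setup where the
# LightningModule exposes the network as `pl_module.model` (as this callback expects):
#
#     from lightning import Trainer
#
#     trainer = Trainer(max_epochs=100, callbacks=[EMA(decay=0.9999, tau=2000)])
#     trainer.fit(pl_module)  # pl_module.ema then holds the averaged weights
#
# With tau=2000, the decay factor reaches 0.9999 * (1 - 1/e) ≈ 0.63 after 2000 steps
# and approaches 0.9999 asymptotically.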


def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
    """Create an optimizer for the given model parameters based on the configuration.

    Returns:
        An instance of the optimizer configured according to the provided settings.
    """
    optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)

    # Split parameters into three groups so weight decay is only applied to conv weights
    bias_params = [p for name, p in model.named_parameters() if "bias" in name]
    norm_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" in name]
    conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]

    model_parameters = [
        {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
        {"params": conv_params, "momentum": 0.937},
        {"params": norm_params, "momentum": 0.937, "weight_decay": 0},
    ]

    def next_epoch(self, batch_num, epoch_idx):
        self.min_lr = self.max_lr
        self.max_lr = [param["lr"] for param in self.param_groups]
        # TODO: load momentum from config instead of a fixed number
        # 0.8  : warm-up momentum
        # 0.937: final momentum
        # 3    : number of warm-up epochs
        self.min_mom = lerp(0.8, 0.937, min(epoch_idx, 3), 3)
        self.max_mom = lerp(0.8, 0.937, min(epoch_idx + 1, 3), 3)
        self.batch_num = batch_num
        self.batch_idx = 0

    def next_batch(self):
        self.batch_idx += 1
        lr_dict = dict()
        for lr_idx, param_group in enumerate(self.param_groups):
            min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
            param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
            lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
            lr_dict[f"momentum/{lr_idx}"] = param_group["momentum"]
        return lr_dict

    # Attach the per-epoch/per-batch warm-up helpers to the optimizer class
    optimizer_class.next_batch = next_batch
    optimizer_class.next_epoch = next_epoch

    optimizer = optimizer_class(model_parameters, **optim_cfg.args)
    # Initial warm-up targets: the bias group starts at LR 0.1 and decays to its
    # configured value, while the conv and norm groups ramp up from 0
    optimizer.max_lr = [0.1, 0, 0]
    return optimizer
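

# A minimal sketch of how the attached helpers are meant to be driven from a training
# loop; `train_loader` and `num_epochs` are illustrative names, not part of this module:
#
#     optimizer = create_optimizer(model, optim_cfg)
#     for epoch_idx in range(num_epochs):
#         optimizer.next_epoch(batch_num=len(train_loader), epoch_idx=epoch_idx)
#         for batch in train_loader:
#             lr_dict = optimizer.next_batch()  # warm up LR/momentum for this batch
#             ...  # forward, backward, optimizer.step()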


def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LRScheduler:
    """Create a learning rate scheduler for the given optimizer based on the configuration.

    Returns:
        An instance of the scheduler configured according to the provided settings.
    """
    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedule_cfg.type)
    schedule = scheduler_class(optimizer, **schedule_cfg.args)
    if hasattr(schedule_cfg, "warmup"):
        wepoch = schedule_cfg.warmup.epochs
        # lambda1 ramps the LR up from 1/wepoch to 1; lambda2 ramps it down from ~10 to 1
        lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1
        lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1
        # Group order matches create_optimizer: [bias, conv, norm]; the bias group gets the decaying factor
        warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1])
        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1])
    return schedule
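

# A minimal sketch of a config this function accepts; the fields `type`, `args`, and
# `warmup.epochs` are the ones it reads, while the concrete values are illustrative:
#
#     from types import SimpleNamespace
#
#     schedule_cfg = SimpleNamespace(
#         type="LinearLR",
#         args={"start_factor": 1.0, "end_factor": 0.01, "total_iters": 100},
#         warmup=SimpleNamespace(epochs=3),
#     )
#     scheduler = create_scheduler(optimizer, schedule_cfg)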


def initialize_distributed() -> int:
    """Initialize the NCCL process group from torchrun environment variables and return the local rank."""
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    logger.info(f"🔢 Initialized process group; rank: {rank}, size: {world_size}")
    return local_rank


def get_device(device_spec: Union[str, int, List[int]]) -> Tuple[torch.device, bool]:
    ddp_flag = False
    if isinstance(device_spec, (list, ListConfig)):
        # A list of devices means DDP: initialize the process group and use this process's local rank
        ddp_flag = True
        device_spec = initialize_distributed()
    if torch.cuda.is_available() and "cuda" in str(device_spec):
        return torch.device(device_spec), ddp_flag
    if not torch.cuda.is_available():
        if device_spec != "cpu":
            logger.warning(f"❎ Device spec: {device_spec} is not supported, choosing CPU instead")
        return torch.device("cpu"), False
    device = torch.device(device_spec)
    return device, ddp_flag
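

# Usage sketch for the three accepted spec forms (values illustrative):
#
#     device, use_ddp = get_device("cuda:0")  # single GPU by name
#     device, use_ddp = get_device(0)         # single GPU by index
#     device, use_ddp = get_device([0, 1])    # device list -> DDP via initialize_distributed()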


class PostProcess:
    """
    Scale the predictions back to the original image coordinates and run NMS on the predicted bounding boxes.
    """

    def __init__(self, converter: Union[Vec2Box, Anc2Box], nms_cfg: NMSConfig) -> None:
        self.converter = converter
        self.nms = nms_cfg

    def __call__(
        self, predict, rev_tensor: Optional[Tensor] = None, image_size: Optional[List[int]] = None
    ) -> List[Tensor]:
        if image_size is not None:
            self.converter.update(image_size)
        prediction = self.converter(predict["Main"])
        pred_class, _, pred_bbox = prediction[:3]
        pred_conf = prediction[3] if len(prediction) == 4 else None
        if rev_tensor is not None:
            # Undo the letterbox transform: subtract the pad offset, then divide by the scale
            pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
        pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)
        return pred_bbox
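

# A minimal usage sketch, assuming the model returns a dict with a "Main" head (as
# `__call__` expects) and the dataloader provides the letterbox reverse tensor:
#
#     post_process = PostProcess(converter, nms_cfg)
#     boxes_per_image = post_process(model(images), rev_tensor, image_size=[640, 640])

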
def collect_prediction(predict_json: List, local_rank: int) -> List:
"""
Collects predictions from all distributed processes and gathers them on the main process (rank 0).
Args:
predict_json (List): The prediction data (can be of any type) generated by the current process.
local_rank (int): The rank of the current process. Typically, rank 0 is the main process.
Returns:
List: The combined list of predictions from all processes if on rank 0, otherwise predict_json.
"""
if dist.is_initialized() and local_rank == 0:
all_predictions = [None for _ in range(dist.get_world_size())]
dist.gather_object(predict_json, all_predictions, dst=0)
predict_json = [item for sublist in all_predictions for item in sublist]
elif dist.is_initialized():
dist.gather_object(predict_json, None, dst=0)
return predict_json
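

# Rank-dependent usage sketch inside a distributed evaluation loop (illustrative):
#
#     predict_json = collect_prediction(predict_json, local_rank)
#     if local_rank == 0:
#         ...  # only rank 0 holds the merged predictions from every process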


def predicts_to_json(img_paths, predicts, rev_tensor):
    """
    Convert a batch of image paths and predictions (an n x 6 tensor per image) into a
    list of COCO-style detection dicts.
    """
    batch_json = []
    for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
        scale, shift = box_reverse.split([1, 4])
        bboxes = bboxes.clone()
        # Map boxes back to the original image coordinates, then convert to COCO's xywh format
        bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
        bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
        for cls, *pos, conf in bboxes:
            bbox = {
                "image_id": int(Path(img_path).stem),
                "category_id": IDX_TO_ID[int(cls)],
                "bbox": [float(p) for p in pos],
                "score": float(conf),
            }
            batch_json.append(bbox)
    return batch_json
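

# Each entry in the returned list is a COCO-style detection dict, e.g. (values illustrative):
#
#     {"image_id": 139, "category_id": 1, "bbox": [412.8, 157.6, 53.0, 138.2], "score": 0.91}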