""" |
Train a model on a dataset |
Usage: |
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16 |
""" |
import math |
import os |
import subprocess |
import time |
from copy import deepcopy |
from datetime import datetime, timedelta |
from pathlib import Path |
import numpy as np |
import torch |
from torch import distributed as dist |
from torch import nn, optim |
from torch.cuda import amp |
from torch.nn.parallel import DistributedDataParallel as DDP |
from tqdm import tqdm |
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights |
from ultralytics.yolo.cfg import get_cfg |
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset |
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, |
clean_url, colorstr, emojis, yaml_save) |
from ultralytics.yolo.utils.autobatch import check_train_batch_size |
from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args |
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command |
from ultralytics.yolo.utils.files import get_latest_run, increment_path |
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, |
select_device, strip_optimizer) |
class BaseTrainer:
    """
    BaseTrainer

    A base class for creating trainers. It drives the shared training loop (warmup, AMP, gradient accumulation,
    EMA, validation, checkpointing, resuming, early stopping and optional multi-GPU DDP launching). Task-specific
    subclasses provide the model, dataloaders, validator and loss bookkeeping by overriding the hook methods
    (get_model, get_dataloader, get_validator, label_loss_items, ...).

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance.
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to last checkpoint.
        best (Path): Path to best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (dict): Dataset information returned by the dataset check (must contain at least a 'train' entry).
        trainset: Training data source as returned by get_dataset (by default data['train']).
        testset: Validation data source as returned by get_dataset (data['val'], falling back to data['test']).
        ema (ModelEMA): EMA (Exponential Moving Average) of the model.
        lf (callable): Learning-rate multiplier as a function of the epoch, used by the LambdaLR scheduler.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Running mean of the loss items over the current epoch.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Pre-built callbacks mapping; when falsy the default callbacks are used.

        Raises:
            RuntimeError: If checking/loading the dataset fails.
        """
        self.args = get_cfg(cfg, overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.check_resume()
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Output directories
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        if hasattr(self.args, 'save_dir'):
            self.save_dir = Path(self.args.save_dir)
        else:
            # Non-master DDP ranks pass exist_ok=True (presumably so they resolve to the same directory as rank 0
            # instead of creating a new incremented one).
            self.save_dir = Path(
                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
        self.wdir = self.save_dir / 'weights'
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # persist the run arguments
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
        self.save_period = self.args.save_period

        self.batch_size = self.args.batch
        self.epochs = self.args.epochs
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))

        # No dataloader worker processes when training on CPU
        if self.device.type == 'cpu':
            self.args.workers = 0

        # Model and dataset
        self.model = self.args.model
        try:
            if self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
                self.data = check_det_dataset(self.args.data)
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']
        except Exception as e:
            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

        self.trainset, self.testset = self.get_dataset(self.data)
        self.ema = None

        # Optimization utilities (built in _setup_train)
        self.lf = None
        self.scheduler = None

        # Epoch-level metrics
        self.best_fitness = None
        self.fitness = None
        self.loss = None
        self.tloss = None
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

    def add_callback(self, event: str, callback):
        """
        Appends the given callback to the list of callbacks registered for `event`.
        """
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        """
        Overrides the existing callbacks for `event` with the single given callback.
        """
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        """Run all existing callbacks associated with a particular event, passing this trainer to each."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def train(self):
        """
        Run training, launching a multi-GPU DDP subprocess when appropriate.

        The world size is derived from `args.device`: an int or any truthy device value uses
        `torch.cuda.device_count()`; an empty/None device defaults to one GPU when CUDA is available, else 0.
        When world_size > 1 and this process is not already a DDP worker ('LOCAL_RANK' not in the environment),
        a DDP command is generated and run in a subprocess; otherwise training runs in this process.
        """
        # Allow device='', device=None on Multi-GPU systems to default to device=0.
        # NOTE(review): string devices such as 'cpu' also take the first branch; this relies on select_device()
        # (called in __init__) having restricted the visible CUDA devices — TODO confirm.
        if isinstance(self.args.device, int) or self.args.device:
            world_size = torch.cuda.device_count()
        elif torch.cuda.is_available():
            world_size = 1
        else:
            world_size = 0

        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
            if self.args.rect:
                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
                self.args.rect = False
            cmd, file = generate_ddp_command(world_size, self)
            try:
                LOGGER.info(f'DDP command: {cmd}')
                subprocess.run(cmd, check=True)
            finally:
                # Clean up the DDP launch artifacts whether or not the subprocess succeeded.
                ddp_cleanup(self, str(file))
        else:
            self._do_train(world_size)

    def _setup_ddp(self, world_size):
        """Initializes and sets the DistributedDataParallel parameters for training."""
        torch.cuda.set_device(RANK)
        self.device = torch.device('cuda', RANK)
        LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
        os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
        dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo',
                                timeout=timedelta(seconds=3600),
                                rank=RANK,
                                world_size=world_size)

    def _setup_train(self, world_size):
        """
        Builds the model, AMP scaler, dataloaders, validator, optimizer and scheduler on the correct rank process.

        Args:
            world_size (int): Number of DDP processes (0 or 1 for single-device training).

        Raises:
            SyntaxError: If batch=-1 (AutoBatch) is requested for multi-GPU DDP training.
        """
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()

        # Check AMP
        self.amp = torch.tensor(self.args.amp).to(self.device)
        if self.amp and RANK in (-1, 0):
            # Back up and restore the default callbacks around check_amp(), presumably because it can alter them.
            callbacks_backup = callbacks.default_callbacks.copy()
            self.amp = torch.tensor(check_amp(self.model), device=self.device)
            callbacks.default_callbacks = callbacks_backup
        if RANK > -1:
            # Share rank 0's AMP decision with every DDP rank.
            dist.broadcast(self.amp, src=0)
        self.amp = bool(self.amp)
        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[RANK])

        # Check imgsz. Read the stride from the unwrapped model: a DDP wrapper does not expose the inner model's
        # plain attributes, which previously forced the grid size to 32 in multi-GPU runs.
        unwrapped = de_parallel(self.model)
        gs = max(int(unwrapped.stride.max() if hasattr(unwrapped, 'stride') else 32), 32)  # grid size (max stride)
        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

        # Batch size
        if self.batch_size == -1:
            if RANK == -1:  # AutoBatch only for single-process training
                self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
            else:
                raise SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                                  'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')

        # Dataloaders (per-process batch size under DDP)
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        if RANK in (-1, 0):
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
            self.ema = ModelEMA(self.model)
            if self.args.plots and not self.args.v5loader:
                self.plot_training_labels()

        # Optimizer: accumulate gradients over round(nbs / batch_size) steps and scale weight decay accordingly
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)

        # Scheduler: one_cycle() when args.cos_lr, otherwise linear decay from 1.0 (epoch 0) to lrf (final epoch)
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # align the scheduler with the (possibly resumed) epoch
        self.run_callbacks('on_pretrain_routine_end')

    def _do_train(self, world_size=1):
        """
        Run the training loop: per-epoch training, validation, checkpointing, early stopping and final evaluation.

        Args:
            world_size (int): Number of DDP processes; the process group is initialised when greater than 1.
        """
        if world_size > 1:
            self._setup_ddp(world_size)

        self._setup_train(world_size)

        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches per epoch
        # Warmup iterations: at least 100, or warmup_epochs worth of batches. warmup_epochs <= 0 disables warmup
        # entirely (nw = -1 never satisfies `ni <= nw` below).
        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                    f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                    f"Logging results to {colorstr('bold', self.save_dir)}\n"
                    f'Starting training for {self.epochs} epochs...')
        if self.args.close_mosaic:
            # Also plot the first batches after mosaic is closed.
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
        epoch = self.epochs  # keeps the completion log below valid if the loop body never runs
        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.model.train()
            if RANK != -1:
                self.train_loader.sampler.set_epoch(epoch)
            pbar = enumerate(self.train_loader)
            if epoch == (self.epochs - self.args.close_mosaic):
                self._close_dataloader_mosaic()
                self.train_loader.reset()

            if RANK in (-1, 0):
                LOGGER.info(self.progress_string())
                pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')
                # Warmup
                ni = i + nb * epoch  # batches seen since the start of training
                if ni <= nw:
                    xi = [0, nw]
                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
                    for j, x in enumerate(self.optimizer.param_groups):
                        # Param group 0 (biases, see build_optimizer) starts from warmup_bias_lr, the others from
                        # 0.0; all interpolate towards initial_lr * lf(epoch).
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                # Forward
                with torch.cuda.amp.autocast(self.amp):
                    batch = self.preprocess_batch(batch)
                    self.loss, self.loss_items = self.model(batch)
                    if RANK != -1:
                        self.loss *= world_size
                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                        else self.loss_items

                # Backward
                self.scaler.scale(self.loss).backward()

                # Optimize once `accumulate` batches have been accumulated since the last step
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

                # Log
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if RANK in (-1, 0):
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}

            self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')

            if RANK in (-1, 0):
                # Validation
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
                if self.args.val or final_epoch:
                    self.metrics, self.fitness = self.validate()
                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                self.stop = self.stopper(epoch + 1, self.fitness)

                # Save model
                if self.args.save or (epoch + 1 == self.epochs):
                    self.save_model()
                    self.run_callbacks('on_model_save')

            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow
            self.run_callbacks('on_fit_epoch_end')
            torch.cuda.empty_cache()

            # Early stopping: rank 0 decides, and the decision is broadcast so every DDP rank breaks together.
            if RANK != -1:
                broadcast_list = [self.stop if RANK == 0 else None]
                dist.broadcast_object_list(broadcast_list, 0)
                if RANK != 0:
                    self.stop = broadcast_list[0]
            if self.stop:
                break

        if RANK in (-1, 0):
            # Final validation with best.pt
            LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                        f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
            self.final_eval()
            if self.args.plots:
                self.plot_metrics()
            self.run_callbacks('on_train_end')
        torch.cuda.empty_cache()
        self.run_callbacks('teardown')

    def _close_dataloader_mosaic(self):
        """Disable mosaic augmentation on the training dataset, if the dataset supports it."""
        LOGGER.info('Closing dataloader mosaic')
        if hasattr(self.train_loader.dataset, 'mosaic'):
            self.train_loader.dataset.mosaic = False
        if hasattr(self.train_loader.dataset, 'close_mosaic'):
            self.train_loader.dataset.close_mosaic(hyp=self.args)

    def save_model(self):
        """
        Save the last checkpoint, plus best.pt when the current fitness is the best so far and an
        `epoch{N}.pt` snapshot every `save_period` epochs. Pickles with `dill` when available, else `pickle`.
        """
        ckpt = {
            'epoch': self.epoch,
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),
            'date': datetime.now().isoformat(),
            'version': __version__}

        try:
            import dill as pickle
        except ImportError:
            import pickle

        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

    @staticmethod
    def get_dataset(data):
        """
        Get the train and val entries from the data dict.

        Returns:
            (tuple): (data['train'], data['val'] or data['test']); the second item is None if neither is set.

        Raises:
            KeyError: If `data` has no 'train' entry.
        """
        return data['train'], data.get('val') or data.get('test')

    def setup_model(self):
        """
        Load/create/download the model for any task.

        Returns:
            (dict | None): The loaded checkpoint when `self.model` is a '.pt' path, otherwise None
            (including when `self.model` is already an nn.Module, in which case nothing is changed).
        """
        if isinstance(self.model, torch.nn.Module):
            return

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
        return ckpt

    def optimizer_step(self):
        """
        Perform one optimizer step: unscale gradients, clip their total norm to 10.0, step through the AMP
        GradScaler, zero the gradients and update the EMA model (if one exists).
        """
        self.scaler.unscale_(self.optimizer)  # unscale before clipping so the norm is computed on true gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

    def preprocess_batch(self, batch):
        """
        Hook for task-specific preprocessing of model inputs and ground truths; the base implementation
        returns the batch unchanged.
        """
        return batch

    def validate(self):
        """
        Run validation on the test set using self.validator.

        The validator's returned dict is expected to contain a "fitness" key; it is popped from the metrics and,
        if absent, the negated current training loss is used instead. best_fitness is updated on improvement.

        Returns:
            (tuple): (metrics dict without 'fitness', fitness value).
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build and return the task model from `cfg`/`weights`; must be implemented by subclasses."""
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        """Return the task validator; must be implemented by subclasses."""
        raise NotImplementedError('get_validator function not implemented in trainer')

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        """
        Return a dataloader for `dataset_path`; must be implemented by subclasses.
        """
        raise NotImplementedError('get_dataloader function not implemented in trainer')

    def build_dataset(self, img_path, mode='train', batch=None):
        """Build a dataset for `img_path`; must be implemented by subclasses."""
        raise NotImplementedError('build_dataset function not implemented in trainer')

    def label_loss_items(self, loss_items=None, prefix='train'):
        """
        Return {'loss': loss_items} when `loss_items` is given, otherwise the list of loss key names (['loss']).
        `prefix` is unused in this base implementation.
        """
        return {'loss': loss_items} if loss_items is not None else ['loss']

    def set_model_attributes(self):
        """
        Set or update model attributes before training (the base implementation copies the class names).
        """
        self.model.names = self.data['names']

    def build_targets(self, preds, targets):
        """Hook for building training target tensors; no-op in the base class."""
        pass

    def progress_string(self):
        """Return the header string logged before each epoch's progress bar (empty in the base class)."""
        return ''

    def plot_training_samples(self, batch, ni):
        """Hook for plotting training samples at integrated batch index `ni`; no-op in the base class."""
        pass

    def plot_training_labels(self):
        """Hook for plotting training label statistics; no-op in the base class."""
        pass

    def save_metrics(self, metrics):
        """Append one row of `metrics` (prefixed by the epoch) to the results CSV, writing a header if new."""
        keys, vals = list(metrics.keys()), list(metrics.values())
        n = len(metrics) + 1  # number of columns, including 'epoch'
        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')
        with open(self.csv, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')

    def plot_metrics(self):
        """Hook for plotting the recorded metrics; no-op in the base class."""
        pass

    def on_plot(self, name, data=None):
        """Register a plot under `name` with a timestamp (e.g. to be consumed in callbacks)."""
        self.plots[name] = {'data': data, 'timestamp': time.time()}

    def final_eval(self):
        """
        Strip optimizer state from the last/best checkpoints that exist, then validate best.pt and run the
        'on_fit_epoch_end' callbacks with the resulting metrics.
        """
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)
                if f is self.best:
                    LOGGER.info(f'\nValidating {f}...')
                    self.metrics = self.validator(model=f)
                    self.metrics.pop('fitness', None)
                    self.run_callbacks('on_fit_epoch_end')

    def check_resume(self):
        """
        If `args.resume` is set, locate the checkpoint (the given path, or the latest run) and replace
        `self.args` with the checkpoint's training arguments. Sets `self.resume`.

        Raises:
            FileNotFoundError: If the resume checkpoint cannot be found or loaded.
        """
        resume = self.args.resume
        if resume:
            try:
                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
                last = Path(check_file(resume) if exists else get_latest_run())
                ckpt_args = attempt_load_weights(last).args
                if not Path(ckpt_args['data']).exists():
                    ckpt_args['data'] = self.args.data  # checkpoint's dataset path is gone; keep the current one
                self.args = get_cfg(ckpt_args)
                self.args.model, resume = str(last), True
            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume

    def resume_training(self, ckpt):
        """
        Restore optimizer state, EMA, best fitness and the start epoch from `ckpt`.

        If the checkpoint has already been trained for at least `self.epochs` epochs, `self.epochs` is extended by
        the checkpoint's epoch count (fine-tuning). Mosaic is closed immediately if resuming past the point where
        it would have been closed.

        Args:
            ckpt (dict | None): Checkpoint returned by setup_model(); nothing is done when None.

        Raises:
            AssertionError: If resuming a checkpoint whose training is already finished (start epoch <= 0).
        """
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        if ckpt['optimizer'] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']
        if self.ema and ckpt.get('ema'):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            self.ema.updates = ckpt['updates']
        if self.resume:
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if not start_epoch > 0:
                raise AssertionError(
                    f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n'
                    f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'")
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
            self.epochs += ckpt['epoch']
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()

    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
        momentum, weight decay, and number of iterations.

        Parameters are split into three groups: biases (group 0, no decay), normalization-layer weights
        (no decay) and all other weights (weight decay = `decay`).

        Args:
            model (torch.nn.Module): The model for which to build an optimizer.
            name (str, optional): The name of the optimizer to use. If 'auto', SGD is selected when
                iterations > 10000, otherwise AdamW with an lr derived from the number of classes. Default: 'auto'.
            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
            iterations (float, optional): The number of iterations, which determines the optimizer if
                name is 'auto'. Default: 1e5.

        Returns:
            (torch.optim.Optimizer): The constructed optimizer.

        Raises:
            NotImplementedError: If `name` is not a supported optimizer.
        """
        g = [], [], []  # optimizer parameter groups: (decayed weights, norm weights, biases)
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layer classes
        if name == 'auto':
            nc = getattr(model, 'nc', 10)  # number of classes
            lr_fit = round(0.002 * 5 / (4 + nc), 6)
            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
            self.args.warmup_bias_lr = 0.0

        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fullname = f'{module_name}.{param_name}' if module_name else param_name
                if 'bias' in fullname:
                    g[2].append(param)
                elif isinstance(module, bn):
                    g[1].append(param)
                else:
                    g[0].append(param)

        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
            raise NotImplementedError(
                f"Optimizer '{name}' not found in list of available optimizers "
                f'[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
        LOGGER.info(
            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
        return optimizer