|
from pathlib import Path |
|
import numpy as np |
|
# Seed the global `np.random` state so legacy NumPy global-RNG calls are repeatable.
# NOTE(review): only NumPy is seeded here; Python's `random` module and any torch RNG are not.
np.random.seed(123)

import ultralytics

# Module-level side effect: `ultralytics.checks()` runs every time this module is imported.
# Presumably it reports environment/software status — confirm against the Ultralytics docs.
ultralytics.checks()
|
from ultralytics import YOLO |
|
|
|
|
|
from typing import Union |
|
|
|
from ultralytics import yolo |
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel, |
|
attempt_load_one_weight) |
|
from ultralytics.yolo.utils import (LOGGER, RANK, yaml_load) |
|
from ultralytics.yolo.utils.checks import check_pip_update_available, check_yaml |
|
|
|
# Task name -> [model class, trainer class, validator class, predictor class].
# YOLO_custom.train instantiates entry 1 (the trainer class) for the selected task.
TASK_MAP = dict(
    classify=[
        ClassificationModel,
        yolo.v8.classify.ClassificationTrainer,
        yolo.v8.classify.ClassificationValidator,
        yolo.v8.classify.ClassificationPredictor,
    ],
    detect=[
        DetectionModel,
        yolo.v8.detect.DetectionTrainer,
        yolo.v8.detect.DetectionValidator,
        yolo.v8.detect.DetectionPredictor,
    ],
    segment=[
        SegmentationModel,
        yolo.v8.segment.SegmentationTrainer,
        yolo.v8.segment.SegmentationValidator,
        yolo.v8.segment.SegmentationPredictor,
    ],
    pose=[
        PoseModel,
        yolo.v8.pose.PoseTrainer,
        yolo.v8.pose.PoseValidator,
        yolo.v8.pose.PosePredictor,
    ],
)
|
|
|
|
|
class YOLO_custom(YOLO):
    """Ultralytics ``YOLO`` model whose ``train`` also accepts a ``hyp`` dict of overrides."""

    def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
        """
        Initialise the model exactly like ``YOLO`` does.

        Args:
            model (str | Path): Model to load; passed unchanged to ``YOLO.__init__``.
            task: Optional task name; passed unchanged to ``YOLO.__init__``.
        """
        super().__init__(model, task)

    @staticmethod
    def _hyp_overrides(hyp) -> dict:
        """
        Return the entries of ``hyp`` that should override the training configuration.

        Entries whose value is ``None`` are dropped so they do not clobber existing
        settings. A new dict is returned, so the caller's ``hyp`` is never modified.
        A ``hyp`` that is neither a dict nor ``None`` is reported with a warning and ignored.
        """
        if isinstance(hyp, dict):
            if hyp:
                LOGGER.info("'hyp' dict passed -> overriding the hyperparameters found in 'hyp'.")
            return {k: v for k, v in hyp.items() if v is not None}
        if hyp is not None:
            LOGGER.warning("WARNING the 'hyp' variable MUST be a dict")
        return {}

    def train(self, hyp: Union[dict, None] = None, **kwargs):
        """
        CAUTION: OVERWRITES THE ORIGINAL METHOD TO ACCEPT HYPERPARAMETERS

        Trains the model on a given dataset.

        The training configuration is built in layers (later layers win):
          1. ``self.overrides`` updated with ``kwargs`` -- or, if ``kwargs['cfg']`` is set,
             the contents of that YAML file alone (it replaces everything before it);
          2. the non-``None`` entries of ``hyp``.
        When a HUB session is attached, ``kwargs`` is replaced by ``self.session.train_args``;
        ``hyp`` is still applied on top of it.

        Args:
            hyp (dict | None): Hyperparameter overrides applied after ``kwargs``/``cfg``.
                Entries whose value is ``None`` are skipped. The dict is not modified.
            **kwargs (Any): Any number of arguments representing the training configuration.

        Raises:
            AttributeError: If the final configuration has no ``data`` entry.
        """
        self._check_is_pytorch_model()
        if self.session:
            # HUB session: its stored training arguments replace the local kwargs.
            if any(kwargs):
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
            self.session.check_disk_space()
        check_pip_update_available()

        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get('cfg'):
            # A cfg file replaces the whole configuration built so far rather than merging into it.
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs['cfg']))

        # Custom extension: 'hyp' is applied last so it wins over both kwargs and cfg.
        overrides.update(self._hyp_overrides(hyp))

        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            # Replace the truthy resume flag with this model's ckpt_path.
            overrides['resume'] = self.ckpt_path
        self.task = overrides.get('task') or self.task
        # TASK_MAP[task] is [model, trainer, validator, predictor]; index 1 is the trainer class.
        self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
        if not overrides.get('resume'):
            # Only build a fresh trainer model when not resuming; pass current weights if a ckpt is loaded.
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.hub_session = self.session
        self.trainer.train()

        # Only on RANK -1/0 (presumably the single-process or main distributed process):
        # reload the best weights and pick up the validator's metrics.
        if RANK in (-1, 0):
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)
|
|
|
if __name__ == "__main__":
    # Executing this file directly does nothing beyond the module-level code above.
    ...
|
|