from pathlib import Path

import numpy as np

# Seed NumPy's global RNG at import time so NumPy-based randomness is reproducible.
np.random.seed(123)

import ultralytics

# Run Ultralytics' environment checks once at import time.
ultralytics.checks()

from ultralytics import YOLO

# imports for the YOLO custom class
from typing import Optional, Union

from ultralytics import yolo  # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
                                  attempt_load_one_weight)
from ultralytics.yolo.utils import LOGGER, RANK, yaml_load
from ultralytics.yolo.utils.checks import check_pip_update_available, check_yaml

# Task name -> [model class, trainer class, validator class, predictor class].
# YOLO_custom.train() only uses index 1 (the trainer class).
TASK_MAP = {
    'classify': [
        ClassificationModel,
        yolo.v8.classify.ClassificationTrainer,
        yolo.v8.classify.ClassificationValidator,
        yolo.v8.classify.ClassificationPredictor,
    ],
    'detect': [
        DetectionModel,
        yolo.v8.detect.DetectionTrainer,
        yolo.v8.detect.DetectionValidator,
        yolo.v8.detect.DetectionPredictor,
    ],
    'segment': [
        SegmentationModel,
        yolo.v8.segment.SegmentationTrainer,
        yolo.v8.segment.SegmentationValidator,
        yolo.v8.segment.SegmentationPredictor,
    ],
    'pose': [
        PoseModel,
        yolo.v8.pose.PoseTrainer,
        yolo.v8.pose.PoseValidator,
        yolo.v8.pose.PosePredictor,
    ],
}
# /imports for the YOLO custom class


def _apply_hyp(overrides: dict, hyp) -> None:
    """Merge user-supplied hyperparameters from ``hyp`` into ``overrides`` in place.

    Entries whose value is ``None`` are skipped so they do not clobber existing
    settings. ``hyp`` itself is never modified (the caller's dict stays intact).
    A falsy ``hyp`` is a no-op; a truthy non-dict is logged and ignored.
    """
    if not hyp:
        return
    if not isinstance(hyp, dict):
        LOGGER.warning("WARNING the 'hyp' variable MUST be a dict")
        return
    LOGGER.info("'hyp' dict passed -> overriding the hyperparameters found in 'hyp'.")
    # Build a filtered copy rather than deleting keys from the caller's dict.
    overrides.update({k: v for k, v in hyp.items() if v is not None})


class YOLO_custom(YOLO):
    """``YOLO`` subclass whose ``train`` additionally accepts a ``hyp`` override dict."""

    def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
        """Forward ``model`` and ``task`` unchanged to ``YOLO.__init__``."""
        super().__init__(model, task)

    def train(self, hyp: Optional[dict] = None, **kwargs):
        """Train the model on a dataset, optionally overriding hyperparameters.

        CAUTION: OVERWRITES THE ORIGINAL ``YOLO.train`` METHOD TO ACCEPT HYPERPARAMETERS.

        Training settings are resolved in this order (later steps win):
          1. ``self.overrides`` updated with ``kwargs``. When a HUB session is
             attached (``self.session``), the session's ``train_args`` replace
             ``kwargs``.
          2. If the resolved kwargs contain ``cfg``, the YAML loaded from that path
             *replaces* the settings from step 1 entirely.
          3. Every non-``None`` entry of ``hyp`` (applied even with a HUB session).
        ``mode`` is then forced to ``'train'``.

        Args:
            hyp: Hyperparameter overrides. ``None`` values are ignored and the
                caller's dict is not modified. A truthy non-dict is logged as a
                warning and ignored.
            **kwargs (Any): Any number of arguments representing the training configuration.

        Raises:
            AttributeError: If the resolved settings contain no ``data`` entry.

        Side effects:
            Sets ``self.task``, ``self.trainer`` and (unless resuming) ``self.model``.
            On RANK -1 or 0, also reloads ``self.model`` from the trainer's best
            weights and updates ``self.overrides`` and ``self.metrics``.
        """
        self._check_is_pytorch_model()
        if self.session:  # Ultralytics HUB session
            if kwargs:
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
            self.session.check_disk_space()
        check_pip_update_available()

        overrides = self.overrides.copy()
        overrides.update(kwargs)
        if kwargs.get('cfg'):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs['cfg']))

        # Custom extension: user hyperparameters take precedence over kwargs/cfg.
        _apply_hyp(overrides, hyp)

        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            # Resume from the checkpoint this model was loaded from.
            overrides['resume'] = self.ckpt_path

        self.task = overrides.get('task') or self.task
        self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.hub_session = self.session  # attach optional HUB session
        self.trainer.train()

        # update model and cfg after training (main process only)
        if RANK in (-1, 0):
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP


if __name__ == "__main__":
    pass