|
from ultralytics import YOLO
|
|
|
|
|
|
class YOLOTrainer:
    """Thin wrapper bundling a YOLO model with its training hyper-parameters.

    The constructor builds the model and stores the settings; ``train`` and
    ``validate`` forward those stored settings to the underlying model.
    """

    def __init__(self, model_config, data_config, batch_size, img_size, epochs, patience):
        """Create the model and remember the run configuration.

        Args:
            model_config: Value passed straight to ``YOLO(...)`` to build the
                model.
            data_config: Dataset description forwarded as ``data=``.
            batch_size: Forwarded as ``batch=``.
            img_size: Forwarded as ``imgsz=``.
            epochs: Forwarded to training as ``epochs=``.
            patience: Forwarded to training as ``patience=``.
        """
        self.model = YOLO(model_config)
        self.data_config = data_config
        self.batch_size = batch_size
        self.img_size = img_size
        self.epochs = epochs
        self.patience = patience

    def train(self):
        """Train the model with the stored settings.

        Returns:
            Whatever ``self.model.train`` returns, so callers can inspect the
            outcome instead of it being silently dropped.
        """
        return self.model.train(
            data=self.data_config,
            batch=self.batch_size,
            imgsz=self.img_size,
            epochs=self.epochs,
            patience=self.patience,
        )

    def validate(self):
        """Validate the model on the stored dataset configuration.

        The dataset, image size and batch size are passed explicitly so the
        call does not depend on state left behind by a previous ``train()``.

        Returns:
            Whatever ``self.model.val`` returns.
        """
        return self.model.val(
            data=self.data_config,
            imgsz=self.img_size,
            batch=self.batch_size,
        )


if __name__ == "__main__":
    # Imported here so this script section is self-contained.
    from pathlib import Path

    # Run configuration. NOTE(review): per the Ultralytics docs a ``.yaml``
    # model config builds a new model rather than loading pretrained
    # weights -- confirm that training from scratch is intended.
    model_config = 'yolov8m.yaml'
    data_config = 'dataset/data.yaml'
    batch_size = 16
    img_size = 640
    epochs = 100
    patience = 20

    trainer = YOLOTrainer(model_config, data_config, batch_size, img_size, epochs, patience)

    trainer.train()
    trainer.validate()

    # Create the output directory first: saving into a missing directory
    # would fail only after the whole training run has already completed.
    output_path = Path('model/best_model.pt')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    trainer.model.save(str(output_path))