File size: 1,633 Bytes
605bd6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from ultralytics import YOLO
# Define a class for training and validating a YOLO model
class YOLOTrainer:
    """Train and validate an Ultralytics YOLO model with a fixed set of settings.

    The dataset, batch size and image size given to the constructor are passed
    explicitly to both :meth:`train` and :meth:`validate`. Validation therefore
    uses the same configuration whether or not ``train()`` was called first.
    """

    def __init__(self, model_config, data_config, batch_size, img_size, epochs, patience):
        """Build the underlying YOLO model and store the training settings.

        Args:
            model_config: Model definition/weights handed to ``YOLO(...)``
                (e.g. ``'yolov8m.yaml'``).
            data_config: Dataset description, passed as Ultralytics ``data=``.
            batch_size: Batch size, passed as ``batch=``.
            img_size: Input image size, passed as ``imgsz=``.
            epochs: Number of training epochs, passed as ``epochs=``.
            patience: Early-stopping patience, passed as ``patience=``.
        """
        # Initialize the YOLO model with the given configuration
        self.model = YOLO(model_config)
        self.data_config = data_config
        self.batch_size = batch_size
        self.img_size = img_size
        self.epochs = epochs
        self.patience = patience

    def _shared_args(self):
        """Return the keyword arguments common to training and validation."""
        return {
            "data": self.data_config,
            "batch": self.batch_size,
            "imgsz": self.img_size,
        }

    def train(self, **overrides):
        """Train the model with the stored settings.

        Args:
            **overrides: Additional Ultralytics training arguments. They are
                applied last, so they replace any stored setting of the same name.

        Returns:
            The value returned by ``self.model.train(...)``.
        """
        args = self._shared_args()
        args.update(epochs=self.epochs, patience=self.patience)
        args.update(overrides)
        return self.model.train(**args)

    def validate(self, **overrides):
        """Validate the model on the configured dataset.

        ``data``, ``batch`` and ``imgsz`` are passed explicitly. Without them,
        ``val()`` would rely on whatever configuration the model object
        happens to carry, which is not guaranteed to be ``data_config``.

        Args:
            **overrides: Additional Ultralytics validation arguments. They are
                applied last, so they replace any stored setting of the same name.

        Returns:
            The value returned by ``self.model.val(...)`` (validation metrics).
        """
        args = self._shared_args()
        args.update(overrides)
        return self.model.val(**args)
# Check if the script is run directly (not imported as a module)
if __name__ == "__main__":
    from pathlib import Path

    # Model definition handed to YOLO(...)
    model_config = 'yolov8m.yaml'
    # Dataset description used for both training and validation
    data_config = 'dataset/data.yaml'
    # Batch size for training
    batch_size = 16
    # Image size for training
    img_size = 640
    # Number of training epochs
    epochs = 100
    # Patience for early stopping
    patience = 20

    # Create a YOLOTrainer object with the specified configurations
    trainer = YOLOTrainer(model_config, data_config, batch_size, img_size, epochs, patience)
    # Train, then validate, the model
    trainer.train()
    trainer.validate()

    # Optional: save the model to a file.
    # The target directory may not exist yet, and saving does not create it.
    # Create it here so a long training run doesn't fail at the very last step.
    output_path = Path('model/best_model.pt')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    trainer.model.save(str(output_path))
|