# Source: uploaded by kothariyashhh in commit 605bd6c ("Upload 6 files").
from pathlib import Path

from ultralytics import YOLO
class YOLOTrainer:
    """Train and validate an Ultralytics YOLO model with fixed run settings.

    The same dataset config, batch size and image size are used for both
    training and validation so the two steps stay consistent.
    """

    def __init__(self, model_config, data_config, batch_size, img_size, epochs, patience):
        """Build the wrapped YOLO model and store the run settings.

        Args:
            model_config: Model spec handed to ``YOLO(...)`` (e.g. ``'yolov8m.yaml'``).
            data_config: Dataset YAML path, passed to Ultralytics as ``data=``.
            batch_size: Batch size, passed as ``batch=``.
            img_size: Input image size, passed as ``imgsz=``.
            epochs: Number of training epochs, passed as ``epochs=``.
            patience: Early-stopping patience, passed as ``patience=``.
        """
        self.model = YOLO(model_config)
        self.data_config = data_config
        self.batch_size = batch_size
        self.img_size = img_size
        self.epochs = epochs
        self.patience = patience

    def train(self, **overrides):
        """Train the model and return the result of ``YOLO.train``.

        Extra keyword arguments are forwarded to ``YOLO.train`` and take
        precedence over the settings stored on this instance.
        """
        args = {
            "data": self.data_config,
            "batch": self.batch_size,
            "imgsz": self.img_size,
            "epochs": self.epochs,
            "patience": self.patience,
        }
        args.update(overrides)
        return self.model.train(**args)

    def validate(self, **overrides):
        """Validate the model and return the metrics from ``YOLO.val``.

        The dataset, batch size and image size are passed explicitly rather
        than relying on settings left behind by a previous ``train()`` call.
        This keeps validation consistent with training and lets it run on a
        model that was not trained in this process. Extra keyword arguments
        are forwarded to ``YOLO.val`` and take precedence.
        """
        args = {
            "data": self.data_config,
            "batch": self.batch_size,
            "imgsz": self.img_size,
        }
        args.update(overrides)
        return self.model.val(**args)
# Run a full train -> validate -> save cycle only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    # Model architecture config. NOTE(review): per the Ultralytics docs a
    # '.yaml' config builds an untrained model from scratch; use e.g.
    # 'yolov8m.pt' to start from pretrained weights instead.
    model_config = 'yolov8m.yaml'
    # Dataset YAML handed to Ultralytics as `data=`.
    data_config = 'dataset/data.yaml'
    batch_size = 16
    img_size = 640
    epochs = 100
    # Patience for early stopping.
    patience = 20

    trainer = YOLOTrainer(model_config, data_config, batch_size, img_size, epochs, patience)
    trainer.train()
    trainer.validate()

    # Optional: save the trained weights (presumably the best checkpoint that
    # Ultralytics reloads after training -- confirm for the installed version).
    save_path = 'model/best_model.pt'
    # Writing the checkpoint fails if the target directory does not exist,
    # so create it first.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    trainer.model.save(save_path)