File size: 1,633 Bytes
605bd6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from ultralytics import YOLO

# Define a class for training and validating a YOLO model
class YOLOTrainer:
    """Train and validate an Ultralytics YOLO model with a fixed set of settings.

    The model is built once from ``model_config``. The remaining settings are
    stored on the instance and forwarded to ``YOLO.train`` / ``YOLO.val``.
    """

    def __init__(self, model_config, data_config, batch_size, img_size, epochs, patience):
        """Create the YOLO model and remember the training/validation settings.

        Args:
            model_config: Model definition or weights file passed to ``YOLO``
                (e.g. ``'yolov8m.yaml'``).
            data_config: Path to the dataset YAML used for training and validation.
            batch_size: Batch size, forwarded as ``batch``.
            img_size: Input image size, forwarded as ``imgsz``.
            epochs: Number of training epochs.
            patience: Early-stopping patience, forwarded as ``patience``.
        """
        self.model = YOLO(model_config)
        self.data_config = data_config
        self.batch_size = batch_size
        self.img_size = img_size
        self.epochs = epochs
        self.patience = patience

    def train(self):
        """Train the model on ``data_config``.

        Returns:
            Whatever ``YOLO.train`` returns (previously discarded). Callers that
            ignored the old ``None`` return are unaffected.
        """
        return self.model.train(
            data=self.data_config,
            batch=self.batch_size,
            imgsz=self.img_size,
            epochs=self.epochs,
            patience=self.patience,
        )

    def validate(self):
        """Validate the model on ``data_config``.

        The dataset, batch size and image size are passed explicitly. This keeps
        validation on the same data and settings as ``train`` even when
        ``validate`` is called on its own, instead of relying on whatever
        dataset the model's stored arguments happen to reference.

        Returns:
            The metrics object returned by ``YOLO.val`` (previously discarded).
        """
        return self.model.val(
            data=self.data_config,
            batch=self.batch_size,
            imgsz=self.img_size,
        )

# Check if the script is run directly (not imported as a module)
if __name__ == "__main__":
    # Local import keeps the fix self-contained; only needed when run as a script.
    from pathlib import Path

    # Training configuration.
    model_config = 'yolov8m.yaml'      # model definition passed to YOLO
    data_config = 'dataset/data.yaml'  # dataset YAML used for train and val
    batch_size = 16
    img_size = 640
    epochs = 100
    patience = 20                      # early-stopping patience (epochs)

    trainer = YOLOTrainer(
        model_config=model_config,
        data_config=data_config,
        batch_size=batch_size,
        img_size=img_size,
        epochs=epochs,
        patience=patience,
    )
    trainer.train()
    trainer.validate()

    # Export the trained weights to a fixed location.
    # NOTE(review): assumes YOLO.train() has reloaded the best checkpoint into
    # trainer.model - confirm for the installed ultralytics version.
    save_path = Path('model/best_model.pt')
    # torch.save (used by YOLO.save) does not create missing parent
    # directories, so ensure 'model/' exists before writing.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    trainer.model.save(str(save_path))