File size: 3,511 Bytes
bdcc619
e3c4fa8
 
 
bdcc619
 
 
e3c4fa8
bdcc619
 
 
 
 
 
d650f60
bdcc619
 
 
 
 
 
 
 
 
 
 
 
8b84a69
bdcc619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a04535e
bdcc619
 
 
a04535e
 
 
 
 
 
 
 
 
d650f60
 
 
 
e3c4fa8
 
bdcc619
e3c4fa8
bdcc619
e3c4fa8
bdcc619
 
 
 
e3c4fa8
bdcc619
e3c4fa8
a3be692
e3c4fa8
bdcc619
e3c4fa8
bdcc619
8b84a69
bdcc619
 
e3c4fa8
 
bdcc619
e3c4fa8
bdcc619
e3c4fa8
bdcc619
e3c4fa8
bdcc619
e977585
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Note: You don't need to modify this file as this script is used to train the model for the project.

All of your work should be done in the data_curation.py script.

You should import your main functions from the data_curation.py script and use them to prepare the dataset for training.

The approved model is `yolov10m` from Ultralytics.

Your predictions must be in a label_field called "predictions" in the dataset.

See here for more details about hyperparameters for this model: https://docs.ultralytics.com/modes/train/#train-settings
"""
import os

import yaml

import fiftyone as fo
import fiftyone.utils.random as four
import fiftyone.utils.huggingface as fouh
from ultralytics import YOLO

from data_curation import prepare_dataset

def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir=".",
    splits=("train", "val"),
):
    """
    Export samples to YOLO format, optionally handling multiple data splits.

    Each requested split is exported with ``fo.types.YOLOv5Dataset`` into
    ``export_dir``. For each split, only the samples tagged with that split
    name are exported. The exception is when ``"val"`` is the only requested
    split: then every sample in ``samples`` is exported as the ``"val"`` split.

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or
            samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains
            the labels. Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data
            will be saved. Defaults to the current directory (".").
        splits (str, iterable of str, optional): The split(s) to export. Can
            be a single split name or a sequence of split names. If None,
            all samples are exported as the "val" split. Defaults to
            ("train", "val").

    Returns:
        None
    """
    # Normalize to a list so the single-"val" check below behaves the same
    # for any sequence type (e.g. a tuple), not only for a list.
    if splits is None:
        splits = ["val"]
    elif isinstance(splits, str):
        splits = [splits]
    else:
        splits = list(splits)

    # When "val" is the only requested split, export the whole collection
    # rather than filtering by tag (presumably for collections that carry
    # no split tags).
    export_everything = splits == ["val"]

    for split in splits:
        split_view = samples if export_everything else samples.match_tags(split)

        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split,
        )


def train_model(*, max_samples=100):
    """
    Train the ``yolov10m`` YOLO model on the challenge training set.

    Steps:
        1. Read ``training_config.yaml`` from this script's directory.
        2. Load the Voxel51 challenge train set from the Hugging Face Hub.
        3. Randomly tag a train/val split using the configured ratios.
        4. Export the split to YOLOv5 format in the current directory.
        5. Train the model and reload the best checkpoint.

    Args:
        max_samples (int or None, optional): Maximum number of samples to
            load from the hub. Defaults to 100, a small testing subset;
            pass ``None`` to load the full dataset.

    Returns:
        The YOLO model loaded from the best checkpoint (``weights/best.pt``)
        in the training run's save directory.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, "training_config.yaml")

    with open(config_path, "r", encoding="utf-8") as file:
        training_config = yaml.safe_load(file)

    # TODO: the default of 100 samples is a testing limit; train on the full
    # dataset with train_model(max_samples=None).
    training_dataset = fouh.load_from_hub(
        "Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set",
        max_samples=max_samples,
    )

    # NOTE(review): prepare_dataset is imported from data_curation but never
    # applied here; confirm whether training_dataset should pass through it
    # before splitting.

    print("Splitting the dataset...")
    four.random_split(
        training_dataset,
        {"train": training_config["train_split"], "val": training_config["val_split"]},
    )
    print("Dataset split completed.")

    print("Exporting dataset to YOLO format...")
    # Exports into the current working directory (export_dir defaults to "."),
    # which is where "dataset.yaml" is read from below.
    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
    )
    print("Dataset export completed.")

    print("Initializing the YOLO model...")
    model = YOLO("yolov10m.pt")
    print("Model initialized.")

    print("Starting model training...")
    results = model.train(
        data="dataset.yaml",
        **training_config["train_params"],
    )
    print("Model training completed.")

    best_model_path = str(results.save_dir / "weights/best.pt")
    best_model = YOLO(best_model_path)
    print("Best model loaded.")
    print(f"Best model saved to: {best_model_path}")

    return best_model


if __name__ == "__main__":
    # Only kick off training when run as a script, not when imported.
    train_model()