|
""" |
|
This script is used to train the model for the project. |
|
|
|
You should import your main functions from the data_curation.py script and use them to prepare the dataset for training. |
|
|
|
The approved model is `yolov8m` from Ultralytics.
|
|
|
Your predictions must be in a label_field called "predictions" in the dataset. |
|
|
|
See here for more details about hyperparameters for this model: https://docs.ultralytics.com/modes/train/#train-settings |
|
|
|
""" |
|
import os |
|
from datetime import datetime |
|
from math import log |
|
import yaml |
|
|
|
import fiftyone as fo |
|
import fiftyone.utils.random as four |
|
import fiftyone.utils.huggingface as fouh |
|
|
|
from ultralytics import YOLO |
|
|
|
from data_curation import prepare_dataset |
|
|
|
def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir="./yolo_formatted",
    splits=("train", "val"),
):
    """
    Export samples to YOLO format, issuing one export call per split.

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or
            samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains
            the labels. Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data
            will be saved. Defaults to "./yolo_formatted".
        splits (str, list, tuple, optional): The split(s) to export. Can be a
            single split name (str) or a sequence of split names. If None,
            all samples are exported as the "val" split.
            Defaults to ("train", "val").

    Returns:
        None
    """
    # Normalise the accepted input shapes (None / single name) to a list.
    if splits is None:
        splits = ["val"]
    elif isinstance(splits, str):
        splits = [splits]

    # When the request is exactly ["val"] (which includes the None and "val"
    # inputs normalised above), every sample is exported under "val" without
    # filtering by tag. Otherwise each split exports only the samples tagged
    # with that split's name. This does not change inside the loop.
    export_untagged = splits == ["val"]

    for split in splits:
        split_view = samples if export_untagged else samples.match_tags(split)

        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split,
        )
|
|
|
def train_model(training_dataset, training_config):
    """
    Train the approved YOLOv8m model on the given dataset.

    The dataset is randomly split into "train" and "val" subsets, exported to
    YOLO format on disk, and used to train ``yolov8m.pt``. The best checkpoint
    written by the training run is then loaded and returned.

    Args:
        training_dataset: The FiftyOne dataset to train on. Its
            ``default_classes`` attribute supplies the class list used for
            the export.
        training_config (dict): Training configuration. The keys read here are
            "train_split" and "val_split" (passed to ``random_split`` as the
            values for the "train" and "val" splits) and "train_params"
            (a dict forwarded as keyword arguments to ``YOLO.train``).

    Returns:
        YOLO: A model loaded from ``weights/best.pt`` inside the training
        run's ``save_dir``.
    """
    # Single source of truth for where the export lands, so the export call
    # and the dataset.yaml path handed to the trainer cannot drift apart.
    export_dir = "./yolo_formatted"

    four.random_split(
        training_dataset,
        {
            "train": training_config["train_split"],
            "val": training_config["val_split"],
        },
    )

    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
        export_dir=export_dir,
    )

    # The project only approves `yolov8m`; do not substitute another variant.
    model = YOLO("yolov8m.pt")

    results = model.train(
        data=f"{export_dir}/dataset.yaml",
        **training_config["train_params"],
    )

    # Reload the best checkpoint rather than returning the in-memory model.
    best_model_path = str(results.save_dir / "weights/best.pt")
    return YOLO(best_model_path)
|
|
|
|
|
|
|
# Script entry point: runs only when this file is executed directly.
# NOTE(review): `run` is not defined in the visible part of this file;
# confirm it is provided before relying on this entry point.
if __name__ == "__main__":
    run()
|
|