DCVAI-Example-1 / script.py
datascienceharp's picture
updates
a3be692
raw
history blame
2.72 kB
"""
This script is used to train the model for the project.
You should import your main functions from the data_curation.py script and use them to prepare the dataset for training.
The approved model is `yolov8m` from Ultralytics.
Your predictions must be in a label_field called "predictions" in the dataset.
See here for more details about hyperparameters for this model: https://docs.ultralytics.com/modes/train/#train-settings
"""
import os
from datetime import datetime
from math import log
import yaml
import fiftyone as fo
import fiftyone.utils.random as four
import fiftyone.utils.huggingface as fouh
from ultralytics import YOLO
from data_curation import prepare_dataset
def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir="./yolo_formatted",
    splits=("train", "val"),
):
    """
    Export samples to YOLO format, optionally handling multiple data splits.

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or
            samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains
            the labels. Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data
            will be saved. Defaults to "./yolo_formatted".
        splits (str, list, tuple, optional): The split(s) to export. Can be a
            single split name (str) or an iterable of split names. If None,
            or if "val" is the only split requested, all samples are exported
            as the "val" split without filtering by tag. Otherwise each split
            exports the samples selected by ``samples.match_tags(split)``.
            Defaults to ("train", "val").

    Returns:
        None
    """
    # Normalize to a fresh list so the (immutable) default is never shared or
    # mutated, and so str / tuple / list inputs all behave the same way.
    if splits is None:
        split_names = ["val"]
    elif isinstance(splits, str):
        split_names = [splits]
    else:
        split_names = list(splits)

    # A lone "val" split means "export everything as val": no tag filtering.
    export_everything = split_names == ["val"]

    for split in split_names:
        split_view = samples if export_everything else samples.match_tags(split)
        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split,
        )
def train_model(training_dataset, training_config):
    """
    Split the dataset, export it in YOLO format, train, and return the best model.

    Args:
        training_dataset: The dataset to train on. It is passed to
            ``four.random_split`` and ``export_to_yolo_format``, and its
            ``default_classes`` attribute supplies the class list.
        training_config: A mapping providing ``'train_split'`` and
            ``'val_split'`` (the values handed to ``four.random_split``) and
            ``'train_params'`` (keyword arguments forwarded to ``model.train``).

    Returns:
        A ``YOLO`` model loaded from ``weights/best.pt`` inside the training
        run's ``save_dir``.
    """
    split_fractions = {
        "train": training_config['train_split'],
        "val": training_config['val_split'],
    }
    four.random_split(training_dataset, split_fractions)

    # Relies on export_to_yolo_format's default export_dir ("./yolo_formatted"),
    # which is where the dataset.yaml referenced below is expected to be.
    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
    )

    # NOTE(review): the module docstring names `yolov8m` as the approved model,
    # while these weights are yolov10m -- confirm which one is intended.
    detector = YOLO("yolov10m.pt")
    train_kwargs = training_config['train_params']
    run_results = detector.train(
        data="./yolo_formatted/dataset.yaml",
        **train_kwargs
    )

    best_weights = run_results.save_dir / "weights/best.pt"
    return YOLO(str(best_weights))
# Script entry point: only runs when the file is executed directly.
# NOTE(review): `run` is not defined in the visible part of this file --
# confirm it exists, otherwise this raises NameError.
if __name__ == "__main__":
    run()