Upload 3 files
- data_curation.py +61 -0
- script.py +168 -0
- training_config.yaml +11 -0
data_curation.py
ADDED
@@ -0,0 +1,61 @@
"""
This script is used to curate the data for the project.

Implement your functions to clean the data and prepare it for model training.

Note: the competition requires that you use FiftyOne for data curation, and you are only allowed to
use the approved dataset from the hub, Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set, which can
be found here: https://huggingface.co/datasets/Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set
"""

import fiftyone as fo
import fiftyone.utils.huggingface as fouh

# Implement functions for data curation. The functions below are just dummy examples.

def shuffle_data(dataset):
    """Shuffle the dataset"""
    return dataset.shuffle(seed=51)

def take_random_sample(dataset):
    """Take a random sample from the dataset"""
    return dataset.take(size=10, seed=51)

def prepare_dataset(name):
    """
    Prepare the dataset for model training.

    Args:
        name (str): The name of the dataset to load. Must be "Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set".

    Returns:
        fiftyone.core.dataset.Dataset: The curated dataset.

    Raises:
        ValueError: If the provided dataset name is not the approved one.

    Note:
        The following code block MUST NOT be removed from your submission:

        APPROVED_DATASET = "Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set"

        if name != APPROVED_DATASET:
            raise ValueError(f"Only the approved dataset '{APPROVED_DATASET}' is allowed for this competition.")

        This ensures that only the approved dataset is used for the competition.
    """
    APPROVED_DATASET = "Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set"

    if name != APPROVED_DATASET:
        raise ValueError(f"Only the approved dataset '{APPROVED_DATASET}' is allowed for this competition.")

    # Load the approved dataset from the hub
    dataset = fouh.load_from_hub(name, split="train")

    # Implement your data curation functions here
    dataset = shuffle_data(dataset)
    dataset = take_random_sample(dataset)

    # Return the curated dataset
    curated_dataset = dataset.clone()
    return curated_dataset
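A minimal sketch of what a real curation step could look like in place of the dummy take_random_sample, assuming fiftyone.brain is available in the environment; the helper name and keep fraction are illustrative, not part of the template:

import fiftyone.brain as fob

def keep_most_unique(dataset, frac=0.8):
    """Illustrative only: keep the most visually unique `frac` of the samples."""
    # compute_uniqueness adds a per-sample "uniqueness" score to the dataset
    fob.compute_uniqueness(dataset)
    n_keep = int(frac * len(dataset))
    # Sort by uniqueness (highest first) and keep the top n_keep samples
    return dataset.sort_by("uniqueness", reverse=True).limit(n_keep)

The returned view can be cloned at the end of prepare_dataset exactly like the dummy pipeline above.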
script.py
ADDED
@@ -0,0 +1,168 @@
"""
This script is used to train the model for the project.

You should import your main functions from the data_curation.py script and use them to prepare the dataset for training.

The approved model is `yolov8m` from Ultralytics.

Your predictions must be in a label_field called "predictions" in the dataset.

You may pass your final selection of hyperparameters as keyword arguments in the load_zoo_model function.

See here for more details about hyperparameters for this model: https://docs.ultralytics.com/modes/train/#train-settings
"""
import os
from datetime import datetime
from math import log
import yaml

import fiftyone as fo
import fiftyone.utils.random as four
import fiftyone.utils.huggingface as fouh

from ultralytics import YOLO

from data_curation import prepare_dataset

def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir="./yolo_formatted",
    splits=["train", "val"]
):
    """
    Export samples to YOLO format, optionally handling multiple data splits.

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains the labels.
            Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data will be saved.
            Defaults to "./yolo_formatted".
        splits (str, list, optional): The split(s) to export. Can be a single split name (str)
            or a list of split names. If None, all samples are exported as the "val" split.
            Defaults to ["train", "val"].

    Returns:
        None
    """
    if splits is None:
        splits = ["val"]
    elif isinstance(splits, str):
        splits = [splits]

    for split in splits:
        split_view = samples if split == "val" and splits == ["val"] else samples.match_tags(split)

        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split
        )

def train_model(training_dataset, training_config):
    """
    Train the YOLO model on the given dataset using the provided configuration.
    """
    four.random_split(training_dataset, {"train": training_config['train_split'], "val": training_config['val_split']})

    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
    )

    model = YOLO("yolov8m.pt")

    # Check if epochs in train_params exceeds 50
    if 'epochs' in training_config['train_params'] and training_config['train_params']['epochs'] > 50:
        raise ValueError("Number of epochs cannot exceed 50. Please adjust the 'epochs' parameter in your training configuration.")

    results = model.train(
        data="./yolo_formatted/dataset.yaml",
        **training_config['train_params']
    )

    best_model_path = str(results.save_dir / "weights/best.pt")
    best_model = YOLO(best_model_path)

    return best_model


def run_inference_on_eval_set(eval_dataset, best_model):
    """
    Run inference on the evaluation set using the best trained model.

    Args:
        eval_dataset (fiftyone.core.dataset.Dataset): The evaluation dataset.
        best_model (YOLO): The best trained YOLO model.

    Returns:
        The dataset eval_dataset with predictions
    """
    eval_dataset.apply_model(best_model, label_field="predictions")
    eval_dataset.save()
    return eval_dataset


def eval_model(dataset_to_evaluate):
    """
    Evaluate the model on the evaluation dataset.

    Args:
        dataset_to_evaluate (fiftyone.core.dataset.Dataset): The evaluation dataset.

    Returns:
        The mean average precision (mAP) of the model on the evaluation dataset.
    """
    current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")

    detection_results = dataset_to_evaluate.evaluate_detections(
        gt_field="ground_truth",
        pred_field="predictions",
        eval_key=f"evalrun_{current_datetime}",
        compute_mAP=True,
    )

    return detection_results.mAP()

def run():
    """
    Main function to run the entire training and evaluation process.

    Returns:
        None
    """
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, 'training_config.yaml')
    with open(config_path, 'r') as file:
        training_config = yaml.safe_load(file)

    # Train set
    curated_train_dataset = prepare_dataset(name="Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set")

    # Public eval set
    public_eval_dataset = fouh.load_from_hub("Voxel51/DCVAI-Challenge-Public-Eval-Set")

    N = len(curated_train_dataset)

    best_trained_model = train_model(training_dataset=curated_train_dataset, training_config=training_config)

    model_predictions = run_inference_on_eval_set(eval_dataset=public_eval_dataset, best_model=best_trained_model)

    mAP_on_public_eval_set = eval_model(dataset_to_evaluate=model_predictions)

    adjusted_mAP = (mAP_on_public_eval_set * log(N)) / N

    # Need to add logic to log the score to the leaderboard; for now, just print
    print(f"The adjusted mean Average Precision (mAP) on the public evaluation set is: {adjusted_mAP:.4f}")


if __name__ == "__main__":
    run()
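Note that the leaderboard score is not raw mAP: run() reports adjusted_mAP = mAP * ln(N) / N, so for a fixed mAP a smaller curated training set yields a higher adjusted score. A quick back-of-the-envelope check with illustrative numbers only (not competition results):

from math import log

# Hypothetical mAP of 0.50 on curated training sets of different sizes
for n in (100, 1_000, 4_000):
    print(n, round(0.50 * log(n) / n, 5))
# 100  -> 0.02303
# 1000 -> 0.00345
# 4000 -> 0.00104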
training_config.yaml
ADDED
@@ -0,0 +1,11 @@
# Dataset split
train_split: 0.9
val_split: 0.1

# Training parameters
train_params:
  epochs: 50
  batch: 16
  imgsz: 640
  lr0: 0.01
  lrf: 0.01
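For reference, a minimal sketch (not part of the submission) of how train_params from this file reaches Ultralytics in script.py; all keys shown are standard YOLOv8 train settings from the linked documentation:

import yaml
from ultralytics import YOLO

with open("training_config.yaml") as f:
    cfg = yaml.safe_load(f)

model = YOLO("yolov8m.pt")
# Unpacking train_params is equivalent to:
# model.train(data=..., epochs=50, batch=16, imgsz=640, lr0=0.01, lrf=0.01)
model.train(data="./yolo_formatted/dataset.yaml", **cfg["train_params"])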