import os

import yaml

import fiftyone as fo
import fiftyone.utils.random as four
import fiftyone.utils.huggingface as fouh

# IMPLEMENT YOUR DATA CURATION FUNCTIONS HERE. THE FUNCTIONS BELOW ARE JUST DUMMY EXAMPLES.

def shuffle_data(dataset):
    """Return ``dataset`` with its samples shuffled.

    A fixed seed (51) is passed to ``dataset.shuffle`` so the resulting order
    should be the same on every run.
    """
    fixed_seed = 51
    shuffled = dataset.shuffle(seed=fixed_seed)
    return shuffled

def take_random_sample(dataset):
    """Return a random subset of 10 samples taken from ``dataset``.

    The subset is drawn with ``dataset.take`` using a fixed seed (51), so the
    same samples should be selected on every run.
    """
    sample_count, fixed_seed = 10, 51
    return dataset.take(size=sample_count, seed=fixed_seed)

# DEFINE YOUR TRAINING HYPERPARAMETERS IN THIS DICTIONARY
training_config = dict(
    # Fractions of the curated dataset assigned to the "train" and "val" splits
    train_split=0.9,
    val_split=0.1,
    # Keyword arguments forwarded verbatim to the YOLO model's train() call
    train_params=dict(
        epochs=1,
        batch=16,
        imgsz=640,
        lr0=0.01,
        lrf=0.01,
    ),
)


# WRAP YOUR DATASET CURATION FUNCTIONS IN THIS FUNCTION
def prepare_dataset():
    """
    Prepare the dataset for model training.

    NOTE: There are lines in this function that you must not modify. They are
    marked with "DO NOT MODIFY".

    The source dataset is loaded from "/tmp/data/train" (presumably a local copy
    of "Voxel51/Data-Centric-Visual-AI-Challenge-Train-Set"). It is then passed
    through the curation functions and cloned into a persistent dataset named
    "curated_dataset".

    The loading line marked "DO NOT MODIFY" MUST NOT be removed from your
    submission. It ensures that only the approved dataset is used for the
    competition.

    Returns:
        fiftyone.core.dataset.Dataset: The curated dataset.
    """

    # DO NOT MODIFY THIS LINE
    dataset = fouh.load_from_hub("/tmp/data/train")

    # WRAP YOUR DATA CURATION FUNCTIONS HERE
    dataset = shuffle_data(dataset)
    dataset = take_random_sample(dataset)

    # DO NOT MODIFY BELOW THIS LINE
    curated_dataset = dataset.clone(name="curated_dataset")

    curated_dataset.persistent = True

    # NOTE(review): because the clone is persistent, re-running may fail if a
    # dataset named "curated_dataset" already exists -- confirm against
    # FiftyOne's clone() behavior.

    # train_model() splits, exports and reads classes from this return value.
    # Without it the caller receives None and the pipeline crashes.
    return curated_dataset

# DO NOT MODIFY THIS FUNCTION
def export_to_yolo_format(
    samples,
    classes,
    label_field="ground_truth",
    export_dir=".",
    splits=["train", "val"]
):
    """
    Export samples to YOLO format, optionally handling multiple data splits.

    NOTE: DO NOT MODIFY THIS FUNCTION.

    Each requested split is exported with ``fo.types.YOLOv5Dataset`` into the
    same ``export_dir``. The split name is passed as the exporter's ``split``
    argument.

    Args:
        samples (fiftyone.core.collections.SampleCollection): The dataset or samples to export.
        classes (list): A list of class names for the YOLO format.
        label_field (str, optional): The field in the samples that contains the labels.
            Defaults to "ground_truth".
        export_dir (str, optional): The directory where the exported data will be saved.
            Defaults to "." (the current working directory).
        splits (str, list, optional): The split(s) to export. Can be a single split name (str)
            or a list of split names. Defaults to ["train", "val"]. If None, it is
            treated as ["val"]. The list default is only rebound here, never
            mutated, so sharing it across calls is harmless.

    Split selection:
        If the normalized splits are exactly ["val"], every sample in ``samples`` is
        exported as the "val" split. Otherwise, each split exports only the samples
        tagged with that split's name (``samples.match_tags(split)``).

    Returns:
        None

    """
    # Normalize ``splits`` to a list of split names.
    if splits is None:
        splits = ["val"]
    elif isinstance(splits, str):
        splits = [splits]

    for split in splits:
        # A lone "val" split exports everything. Otherwise, select samples by the
        # tag matching this split (presumably applied by four.random_split in
        # train_model()).
        split_view = samples if split == "val" and splits == ["val"] else samples.match_tags(split)
        
        split_view.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field=label_field,
            classes=classes,
            split=split
        )

# DO NOT MODIFY THIS FUNCTION
def train_model(training_config=training_config):
    """
    Train the YOLO model on the curated dataset using the provided configuration.

    NOTE: DO NOT MODIFY THIS FUNCTION AT ALL OR YOUR SCRIPT WILL FAIL.

    Pipeline:
        1. Build the dataset with prepare_dataset().
        2. Split it into "train"/"val" with four.random_split.
        3. Export it in YOLO format to the current directory.
        4. Train a model initialized from "/tmp/data/yolo11m.pt".
        5. Print the path of the best weights.

    Args:
        training_config (dict): Must provide "train_split" and "val_split" (split
            fractions) and "train_params" (keyword arguments forwarded to
            model.train()). Defaults to the module-level training_config, bound
            when this function is defined.

    Returns:
        None. The best weights path is printed, not returned.
    """

    # prepare_dataset() must return the curated dataset; every call below uses it.
    training_dataset = prepare_dataset()

    print("Splitting the dataset...")

    # Presumably tags each sample "train" or "val", which export_to_yolo_format()
    # then selects via match_tags -- confirm against fiftyone.utils.random docs.
    four.random_split(training_dataset, {"train": training_config['train_split'], "val": training_config['val_split']})
    
    print("Dataset split completed.")

    print("Exporting dataset to YOLO format...")

    # Relies on export_to_yolo_format's defaults: label_field="ground_truth",
    # export_dir=".", splits=["train", "val"].
    export_to_yolo_format(
        samples=training_dataset,
        classes=training_dataset.default_classes,
    )

    print("Dataset export completed.")

    print("Initializing the YOLO model...")

    # NOTE(review): YOLO is not imported in this file's import block.
    # Presumably `from ultralytics import YOLO` is expected at the top; without
    # it the line below raises NameError. Confirm.
    #DO NOT MODIFY THIS LINE
    model = YOLO(
        model="/tmp/data/yolo11m.pt",
        
    )
    
    print("Model initialized.")

    print("Starting model training...")

    # NOTE(review): "dataset.yaml" is a relative path. Presumably the YOLO export
    # above writes it into export_dir="." -- confirm.
    results = model.train(
        data="dataset.yaml",
        **training_config['train_params']
    )

    print("Model training completed.")

    # save_dir is joined with "/", so it is presumably a pathlib.Path.
    best_model_path = str(results.save_dir / "weights/best.pt")

    print(f"Best model saved to: {best_model_path}")

# DO NOT MODIFY THE BELOW
if __name__=="__main__":
    # Script entry point: run the prepare -> split -> export -> train pipeline.
    train_model()