from typing import Sequence

import numpy as np
from norfair import AbsolutePaths, Paths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator
from norfair.distances import create_normalized_mean_euclidean_distance

from custom_models import YOLO, yolo_detections_to_norfair_detections
from demo_utils.configuration import (
    DISTANCE_THRESHOLD_BBOX,
    DISTANCE_THRESHOLD_CENTROID,
    examples,
    models_path,
    style,
)
from demo_utils.draw import center, draw


def _detection_mask(frame, detections, use_bbox):
    """Return a per-pixel mask for camera-motion estimation on ``frame``.

    The mask has the frame's height/width and dtype and is filled with 1.
    When ``use_bbox`` is true, the rectangle spanned by each detection's first
    two points (treated as opposite ``(x, y)`` corners) is set to 0 --
    presumably so that MotionEstimator ignores moving objects and only samples
    background points; see norfair's MotionEstimator docs to confirm.
    """
    mask = np.ones(frame.shape[:2], frame.dtype)
    if use_bbox:
        for det in detections:
            # Clamp negative coordinates to 0: a negative slice start would
            # index from the end of the axis and blank the wrong region.
            corners = np.clip(det.points.astype(int), 0, None)
            mask[corners[0, 1] : corners[1, 1], corners[0, 0] : corners[1, 0]] = 0
    return mask


def inference(
    input_video: str,
    model: str = "YOLOv7",
    features: Sequence[int] = (0, 1),
    track_points: str = "Bounding box",
    model_threshold: float = 0.25,
):
    """Run YOLO detection and norfair tracking on a video and write the result.

    Args:
        input_video: Path to the input video. If its path contains the name of
            one of the configured ``examples``, that example's
            ``absolute_path``, ``distance_threshold`` and ``classes`` settings
            override the defaults.
        model: Key into ``models_path`` selecting the YOLO weights to load.
        features: Indices of the enabled optional features; only the first
            two entries are inspected. ``0`` enables camera-motion estimation,
            ``1`` enables drawing object paths.
        track_points: Key into ``style``. ``"Bounding box"`` tracks boxes with
            the ``"iou"`` distance; any other style tracks with ``"euclidean"``.
        model_threshold: Confidence threshold passed to the detector.

    Returns:
        ``<name>_out.mp4``, where ``<name>`` is the last path component of
        ``input_video`` up to its first ``"."``.
    """
    track_points = style[track_points]
    use_bbox = track_points == style["Bounding box"]
    model = YOLO(models_path[model])
    video = Video(input_path=input_video)

    selected = features[:2]
    motion_estimation = 0 in selected
    drawing_paths = 1 in selected

    motion_estimator = None
    if motion_estimation:
        motion_estimator = MotionEstimator(
            max_points=500,
            min_distance=7,
            transformations_getter=HomographyTransformationGetter(),
        )

    distance_function = "iou" if use_bbox else "euclidean"
    distance_threshold = (
        DISTANCE_THRESHOLD_BBOX if use_bbox else DISTANCE_THRESHOLD_CENTROID
    )
    # Paths are drawn in absolute (camera-compensated) coordinates only when
    # both motion estimation and path drawing are enabled.
    fix_paths = motion_estimation and drawing_paths
    classes = None

    # Example videos carry their own tuned settings, which override the above.
    for example in examples:
        if example not in input_video:
            continue
        fix_paths = examples[example]["absolute_path"]
        distance_threshold = examples[example]["distance_threshold"]
        classes = examples[example]["classes"]

        print(f"Set config to {example}: {fix_paths} {distance_threshold} {classes}")
        break

    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )

    # AbsolutePaths takes precedence over Paths when paths must be fixed.
    paths_drawer = None
    if fix_paths:
        paths_drawer = AbsolutePaths(max_history=50, thickness=2)
    elif drawing_paths:
        paths_drawer = Paths(center, attenuation=0.01)

    coord_transformations = None
    for frame in video:
        yolo_detections = model(
            frame,
            conf_threshold=model_threshold,
            iou_threshold=0.45,
            image_size=720,
            classes=classes,
        )
        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=track_points
        )

        # Estimate camera motion for the current frame *before* updating the
        # tracker, so the tracker and draw() both use this frame's
        # transformation rather than the previous frame's.
        if motion_estimator is not None:
            mask = _detection_mask(frame, detections, use_bbox)
            coord_transformations = motion_estimator.update(frame, mask)

        tracked_objects = tracker.update(
            detections=detections, coord_transformations=coord_transformations
        )

        frame = draw(
            paths_drawer,
            track_points,
            frame,
            detections,
            tracked_objects,
            coord_transformations,
            fix_paths,
        )
        video.write(frame)

    # NOTE(review): presumably mirrors the output name norfair's Video derives
    # from input_path -- keep the two in sync (e.g. names with several dots).
    base_file_name = input_video.split("/")[-1].split(".")[0]
    return base_file_name + "_out.mp4"