import os
import argparse

import onnxruntime as ort
from utils import *


# Model configuration handed to ``postprocess`` (from ``utils``).  The keys
# look like prior-box settings for the mobilenet0.25 backbone
# (per-level anchor sizes, feature strides, box-decoding variances) --
# NOTE(review): exact semantics are defined in utils.postprocess.
CFG = dict(
    name="mobilenet0.25",
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    clip=False,
)

# Size the input image is resized to by ``preprocess``; which entry is the
# height and which the width is decided there -- TODO confirm H/W order.
INPUT_SIZE = [608, 640]

# Torch device passed to preprocess/postprocess.  The ONNX session chooses
# its own execution providers independently of this value.
DEVICE = torch.device("cpu")


def vis(img_raw, dets, vis_thres, save_dir="./results/"):
    """Draw detections onto the image and save it as ``result.jpg``.

    Args:
        img_raw: original image (as returned by cv2.imread); it is drawn on
            in place.
        dets: detection rows laid out as
            [left, top, right, bottom, score, x1, y1, ..., x5, y5]
            (4 box values, 1 score, 5 landmark (x, y) pairs).
        vis_thres: detections whose score is below this value are skipped.
        save_dir: directory that receives ``result.jpg``; created if it
            does not exist.  Defaults to ``./results/``.
    Returns:
        The annotated image (the same object as ``img_raw``).
    Raises:
        OSError: if OpenCV fails to write the output file.
    """
    # One colour per landmark, in the order the landmarks appear in a row.
    landmark_colors = (
        (0, 0, 255),
        (0, 255, 255),
        (255, 0, 255),
        (0, 255, 0),
        (255, 0, 0),
    )
    for det in dets:
        score = det[4]
        if score < vis_thres:
            continue
        # Format the score before truncating the row to ints.
        text = "{:.4f}".format(score)
        b = list(map(int, det))
        cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
        # Text baseline sits 12 px below the box's top-left corner.
        cv2.putText(img_raw, text, (b[0], b[1] + 12),
                    cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))

        # landms: consecutive (x, y) pairs starting at column 5.
        for i, color in enumerate(landmark_colors):
            point = (b[5 + 2 * i], b[6 + 2 * i])
            cv2.circle(img_raw, point, 1, color, 4)

    # save image; exist_ok avoids the exists()/makedirs() race.
    os.makedirs(save_dir, exist_ok=True)
    name = os.path.join(save_dir, "result.jpg")
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(name, img_raw):
        raise OSError("Failed to write visualization to {}".format(name))
    return img_raw


def Retinaface_inference(run_ort, args):
    """Infer one image with an ONNX Runtime session.

    Args:
        run_ort: onnxruntime InferenceSession for the RetinaFace model.
        args: namespace providing ``image_path``, ``confidence_threshold``,
            ``nms_threshold``, ``save_image`` and ``vis_thres``.
    Returns:
        (boxes_list, confidence_list, landm_list) where
            boxes_list = [[left, top, right, bottom], ...]
            confidence_list = [[confidence], ...]
            landm_list = [[landms (dim=10)], ...]
    Raises:
        FileNotFoundError: if ``args.image_path`` cannot be read as an image.
    """
    img_raw = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of deep inside
    # preprocess.
    if img_raw is None:
        raise FileNotFoundError("Could not read image: {}".format(args.image_path))

    # preprocess
    img, scale, resize = preprocess(img_raw, INPUT_SIZE, DEVICE)
    # to NHWC -- the model input layout used by this script.
    img = np.transpose(img, (0, 2, 3, 1))
    # forward
    input_name = run_ort.get_inputs()[0].name
    outputs = run_ort.run(None, {input_name: img})
    # postprocess (receives the NHWC array, as before)
    dets = postprocess(CFG, img, outputs, scale, resize,
                       args.confidence_threshold, args.nms_threshold, DEVICE)

    # Split each detection row into box / score / landmark lists.
    boxes_list = dets[:, :4].tolist()
    confidence_list = dets[:, 4:5].tolist()
    landm_list = dets[:, 5:].tolist()

    # save image
    if args.save_image:
        vis(img_raw, dets, args.vis_thres)

    return boxes_list, confidence_list, landm_list


def build_arg_parser():
    """Build the command-line parser for the RetinaFace ONNX demo."""
    parser = argparse.ArgumentParser(description="Retinaface")
    parser.add_argument(
        "-m",
        "--trained_model",
        default="./weights/RetinaFace_int.onnx",
        type=str,
        help="Path of the ONNX model file to load",
    )
    parser.add_argument(
        "--image_path",
        default="./data/widerface/val/images/18--Concerts/18_Concerts_Concerts_18_38.jpg",
        type=str,
        help="image path",
    )
    parser.add_argument(
        "--confidence_threshold",
        default=0.4,
        type=float,
        help="confidence_threshold",
    )
    parser.add_argument(
        "--nms_threshold",
        default=0.4,
        type=float,
        help="nms_threshold",
    )
    parser.add_argument(
        "-s",
        "--save_image",
        action="store_true",
        help="Save the detection visualization to ./results/result.jpg",
    )
    parser.add_argument(
        "--vis_thres",
        default=0.5,
        type=float,
        help="visualization_threshold",
    )
    parser.add_argument(
        "--ipu",
        action="store_true",
        help="Use IPU for inference.",
    )
    parser.add_argument(
        "--provider_config",
        type=str,
        default="vaip_config.json",
        help="Path of the config file for setting provider_options.",
    )
    return parser


def select_providers(args, available=None):
    """Return ``(providers, provider_options)`` for ``ort.InferenceSession``.

    With ``--ipu`` only the VitisAI execution provider is used, configured
    from ``args.provider_config``.  Otherwise CUDA then CPU are requested,
    keeping only those the installed onnxruntime build offers, so a
    CPU-only build does not warn about a missing CUDA provider.

    Args:
        args: parsed namespace with ``ipu`` and ``provider_config``.
        available: provider names to choose from; defaults to
            ``ort.get_available_providers()``.
    """
    if args.ipu:
        return ["VitisAIExecutionProvider"], [{"config_file": args.provider_config}]
    if available is None:
        available = ort.get_available_providers()
    preferred = ("CUDAExecutionProvider", "CPUExecutionProvider")
    providers = [p for p in preferred if p in available]
    # CPU is the universal fallback if the filter removed everything.
    return providers or ["CPUExecutionProvider"], None


def main(argv=None):
    """Parse arguments, load the ONNX model and run inference on one image."""
    args = build_arg_parser().parse_args(argv)
    providers, provider_options = select_providers(args)

    print("Loading pretrained model from {}".format(args.trained_model))
    run_ort = ort.InferenceSession(
        args.trained_model, providers=providers, provider_options=provider_options
    )

    boxes_list, confidence_list, landm_list = Retinaface_inference(run_ort, args)
    print('inference done!')
    print(boxes_list, confidence_list, landm_list)


if __name__ == '__main__':
    main()