import onnxruntime
import argparse
import os
from utils import *


def pre_process(img):
    """
    Turn an OpenCV image into the batched float tensor fed to the network.

    The image is passed through ``letterbox`` (with ``auto=False``), moved
    from HWC to CHW layout, has its channel order reversed (BGR to RGB),
    is cast to float32, divided by 255 and given a leading batch axis.

    Args:
        img (numpy.ndarray): H x W x C, image read with OpenCV
    Returns:
        numpy.ndarray: 1 x C x H' x W' float32 array, where H' x W' is the
        size produced by ``letterbox``
    """
    # letterbox returns a sequence; only its first element (the image) is used
    boxed = letterbox(img, auto=False)[0]

    # HWC to CHW, then flip the channel axis: BGR to RGB
    chw_rgb = boxed.transpose((2, 0, 1))[::-1]

    # The reversed view has negative strides; make it contiguous before
    # casting and normalising.
    scaled = np.ascontiguousarray(chw_rgb).astype("float32") / 255.0

    # Prepend the batch dimension
    return scaled[np.newaxis, :]


def post_process(x, conf_thres=0.1, iou_thres=0.6, multi_label=True,
                 classes=None, agnostic=False):
    """
    Post-processing part of YOLOv3: decode the raw head outputs into boxes
    and run non-maximum suppression on them.

    Grid size, number of anchors per head and number of values per anchor
    are derived from the data, so the function is not tied to a 416 x 416
    input or to 80 classes.

    Args:
        x (sequence of numpy.ndarray): raw network outputs, one per detection
            head, each shaped (bs, na * no, ny, nx). They must be ordered
            from the coarsest head to the finest one (strides 32, 16, 8).
        conf_thres (float): confidence threshold forwarded to
            ``non_max_suppression``
        iou_thres (float): IoU threshold forwarded to ``non_max_suppression``
        multi_label (bool): forwarded to ``non_max_suppression``
        classes: forwarded to ``non_max_suppression``
        agnostic (bool): forwarded to ``non_max_suppression``
    Returns:
        pred: the value returned by ``non_max_suppression``. The script in
        this file indexes it per image (``pred[0]``) and treats each entry
        as n x 6: dets[:,:4] -> boxes, dets[:,4] -> scores,
        dets[:,5] -> class indices.
    Raises:
        ValueError: if a head's channel count is not a multiple of the
            number of anchors, or if the heads do not agree on the network
            input size (e.g. the outputs are in the wrong order).
    """
    stride = [32, 16, 8]
    # Anchors in pixels, listed from the finest head (stride 8) to the
    # coarsest one (stride 32), i.e. in the reverse order of ``stride``.
    anchors = [[10, 13, 16, 30, 33, 23],
               [30, 61, 62, 45, 59, 119],
               [116, 90, 156, 198, 373, 326]]
    res = []

    def create_grids(ng=(13, 13)):
        """Return the (1, 1, ny, nx, 2) tensor of (x, y) cell offsets."""
        nx, ny = ng  # x and y grid size

        # Same values as torch.meshgrid(arange(ny), arange(nx)) with 'ij'
        # indexing, built with expand so no deprecation warning is emitted.
        yv = torch.arange(ny).view(ny, 1).expand(ny, nx)
        xv = torch.arange(nx).view(1, nx).expand(ny, nx)
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    net_hw = None  # network input size implied by the first head
    for i, (s, anchor_px) in enumerate(zip(stride, reversed(anchors))):
        out = torch.from_numpy(x[i])

        bs, channels, ny, nx = out.shape  # e.g. bs, 255, 13, 13

        anchor = torch.tensor(anchor_px, dtype=torch.float).view(-1, 2)
        na = anchor.shape[0]  # anchors per head
        if channels % na:
            raise ValueError(
                'output %d has %d channels, not a multiple of %d anchors'
                % (i, channels, na))
        no = channels // na  # values per anchor (box, objectness, classes)

        # Every head must map back to the same input size; this catches
        # outputs supplied in the wrong order instead of decoding them with
        # the wrong stride and anchors.
        head_hw = (ny * s, nx * s)
        if net_hw is None:
            net_hw = head_hw
        elif head_hw != net_hw:
            raise ValueError(
                'output %d has grid %dx%d at stride %d, which does not '
                'match the input size %dx%d implied by output 0'
                % (i, ny, nx, s, net_hw[0], net_hw[1]))

        # Anchors expressed in grid units, broadcastable over the head
        anchor_wh = (anchor / s).view(1, na, 1, 1, 2)

        grid = create_grids((nx, ny))

        # (bs, na * no, ny, nx) -> (bs, na, ny, nx, no)
        out = out.reshape(bs, na, no, ny, nx).permute(
            0, 1, 3, 4, 2).contiguous()  # prediction

        # Work on a copy so the caller's numpy buffer is never modified
        io = out.clone()

        io[..., :2] = torch.sigmoid(io[..., :2]) + grid      # xy, grid units
        io[..., 2:4] = torch.exp(io[..., 2:4]) * anchor_wh   # wh, grid units
        io[..., :4] *= s                                     # to pixels
        torch.sigmoid_(io[..., 4:])                          # obj + classes

        res.append(io.view(bs, -1, no))

    pred = non_max_suppression(torch.cat(res, 1), conf_thres,
                               iou_thres, multi_label=multi_label,
                               classes=classes, agnostic=agnostic)

    return pred


def main():
    """
    Run YOLOv3 ONNX inference on one image and save the annotated result.

    Parses the command line, loads the class names from ``coco.names`` in
    the current directory, runs the model on the image given by ``--img``
    and writes ``demo_infer.jpg`` into the directory given by ``--out``.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        IOError: if the annotated image cannot be written.
    """
    parser = argparse.ArgumentParser(
        prog='One image inference of onnx model')
    parser.add_argument(
        '--img',
        type=str,
        required=True,  # without it cv2.imread(None) fails with a cryptic error
        help='Path of input image')
    parser.add_argument(
        '--out',
        type=str,
        default='.',
        help='Path of out put image')
    parser.add_argument(
        "--ipu",
        action="store_true",
        help="Use IPU for inference.")
    parser.add_argument(
        "--provider_config",
        type=str,
        default="vaip_config.json",
        help="Path of the config file for seting provider_options.")
    parser.add_argument(
        "--onnx_path",
        type=str,
        default="yolov3-8.onnx",
        help="Path of the onnx model.")

    opt = parser.parse_args()

    # One class name per line. Keeping the file as a single string would make
    # names[int(cls)] return one character instead of a class name.
    with open('coco.names', 'r') as f:
        names = [line.strip() for line in f if line.strip()]

    path = opt.img
    # cv2.imread signals failure by returning None rather than raising
    img0 = cv2.imread(path)
    if img0 is None:
        raise FileNotFoundError('Could not read input image: %s' % path)

    if opt.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": opt.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None

    onnx_model = onnxruntime.InferenceSession(
        opt.onnx_path, providers=providers, provider_options=provider_options)

    conf_thres, iou_thres, classes, agnostic_nms = 0.25, 0.45, None, False

    img = pre_process(img0)
    onnx_input = {onnx_model.get_inputs()[0].name: img}
    onnx_output = onnx_model.run(None, onnx_input)
    pred = post_process(onnx_output, conf_thres,
                        iou_thres, multi_label=False,
                        classes=classes, agnostic=agnostic_nms)

    # One random colour per class
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
    det = pred[0]
    im0 = img0.copy()

    # Some non_max_suppression implementations return None for an image
    # without detections, so guard before calling len().
    if det is not None and len(det):
        # Rescale boxes from imgsz to im0 size
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

        # Write results
        for *xyxy, conf, cls in reversed(det):
            label = '%s %.2f' % (names[int(cls)], conf)
            plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])

    # Save results; cv2.imwrite returns False instead of raising, and it
    # cannot create missing directories.
    if opt.out:
        os.makedirs(opt.out, exist_ok=True)
    new_path = os.path.join(opt.out, "demo_infer.jpg")
    if not cv2.imwrite(new_path, im0):
        raise IOError('Could not write output image: %s' % new_path)


if __name__ == '__main__':
    main()