# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
import math
from typing import Iterable
import os
import matplotlib.pyplot as plt
import random
import torch
import numpy as np
import time
import base64
from io import BytesIO

import util.misc as misc
import util.lr_sched as lr_sched

from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
from pytorch3d.transforms import RotateAxisAngle
from pytorch3d.io import IO


def evaluate_points(predicted_xyz, gt_xyz, dist_thres):
    """Point-cloud precision, recall and F1 at a fixed distance threshold.

    A predicted point counts toward precision when at least one ground-truth
    point lies strictly closer than ``dist_thres``. A ground-truth point counts
    toward recall when at least one predicted point lies strictly closer than
    ``dist_thres``. Pairwise distances are computed in slices of 1000 points, so
    the full (n_pred, n_gt) distance matrix is never built at once.

    Args:
        predicted_xyz: (n_pred, 3) predicted points (tensors in this file's callers).
        gt_xyz: (n_gt, 3) ground-truth points.
        dist_thres: Euclidean distance threshold.

    Returns:
        ``(precision, recall, f1)``. All three are ``0.0`` when there are no
        predicted points. Recall is ``0.0`` when there are no ground-truth points.
    """
    if predicted_xyz.shape[0] == 0:
        return 0.0, 0.0, 0.0
    slice_size = 1000
    n_pred = predicted_xyz.shape[0]
    n_gt = gt_xyz.shape[0]

    # Precision: slice over predicted points, compare each slice with all GT points.
    precision = 0.0
    for start in range(0, n_pred, slice_size):
        end = start + slice_size
        dist = ((predicted_xyz[start:end, None] - gt_xyz[None]) ** 2.0).sum(axis=-1) ** 0.5
        precision += ((dist < dist_thres).sum(axis=1) > 0).sum()
    precision /= n_pred

    # Recall: slice over *GT* points so that every GT point is visited,
    # regardless of how many points were predicted.
    recall = 0.0
    for start in range(0, n_gt, slice_size):
        end = start + slice_size
        dist = ((predicted_xyz[:, None] - gt_xyz[None, start:end]) ** 2.0).sum(axis=-1) ** 0.5
        recall += ((dist < dist_thres).sum(axis=0) > 0).sum()
    if n_gt > 0:
        # An empty GT cloud leaves recall at 0.0 instead of dividing by zero.
        recall /= n_gt
    return precision, recall, get_f1(precision, recall)

def aug_xyz(seen_xyz, unseen_xyz, args, is_train):
    """Apply one shared random geometric augmentation to seen and unseen points.

    When ``is_train`` is true, both point sets get the same random per-axis
    scale drawn from ``[1 - args.random_scale_delta, 1 + args.random_scale_delta]``.
    Unless ``args.use_hypersim`` is set, they also get a random translation in
    ``[-args.random_shift, args.random_shift]`` per axis. They are then rotated
    about X, Y and Z by random integer angles in
    ``[-args.random_rotate_degree, args.random_rotate_degree]``. Otherwise,
    ``rotate`` is called with all-zero angles.

    Returns:
        ``[seen_xyz, unseen_xyz]``. ``seen_xyz`` keeps its (B, H, W, 3) shape.
    """
    angles = (0, 0, 0)
    if is_train:
        delta = args.random_scale_delta
        scale = torch.tensor(
            [random.uniform(1.0 - delta, 1.0 + delta) for _ in range(3)],
            device=seen_xyz.device,
        )

        if not args.use_hypersim:
            bound = args.random_rotate_degree
            angles = tuple(random.randrange(-bound, bound + 1) for _ in range(3))

            r_shift = args.random_shift
            shift = torch.tensor(
                [[[random.uniform(-r_shift, r_shift) for _ in range(3)]]],
                device=seen_xyz.device,
            )
        else:
            shift = 0
        seen_xyz = seen_xyz * scale + shift
        unseen_xyz = unseen_xyz * scale + shift

    batch, height, width, _ = seen_xyz.shape
    flat_seen = seen_xyz.reshape((batch, -1, 3))
    rotated_seen = rotate(flat_seen, *angles).reshape((batch, height, width, 3))
    rotated_unseen = rotate(unseen_xyz, *angles)
    return [rotated_seen, rotated_unseen]


def rotate(sample, degree_x, degree_y, degree_z):
    """Rotate points about X, then Y, then Z using pytorch3d's ``RotateAxisAngle``.

    Axes with a zero angle are skipped, so ``rotate(x, 0, 0, 0)`` returns ``x``
    itself without building any transform.
    """
    rotated = sample
    for axis, degree in zip("XYZ", (degree_x, degree_y, degree_z)):
        if degree == 0:
            continue
        transform = RotateAxisAngle(degree, axis=axis).to(rotated.device)
        rotated = transform.transform_points(rotated)
    return rotated


def get_grid(B, device, co3d_world_size, granularity):
    """Dense cubic grid of query points covering ``[-co3d_world_size, co3d_world_size)^3``.

    Each axis has ``N = ceil(2 * co3d_world_size / granularity)`` ticks. Index
    ``i`` maps to ``(i - N / 2) / ((N / 2) / co3d_world_size)``. The grid is
    flattened with the last (Z) index varying fastest and repeated for each of
    the ``B`` batch entries.

    Returns:
        Tensor of shape (B, N**3, 3) on ``device``.
    """
    n = int(np.ceil(2 * co3d_world_size / granularity))
    ticks = torch.arange(n, device=device, dtype=torch.get_default_dtype())
    full = (n, n, n)
    grid = torch.stack(
        [
            ticks.view(n, 1, 1).expand(full),
            ticks.view(1, n, 1).expand(full),
            ticks.view(1, 1, n).expand(full),
        ],
        dim=-1,
    )
    grid -= (n / 2.0)
    grid /= (n / 2.0) / co3d_world_size
    return grid.reshape((1, -1, 3)).repeat(B, 1, 1)


def run_viz(model, data_loader, device, args, epoch):
    """Write HTML visualisations of model predictions for the first few batches.

    For each of the first ``args.max_n_viz_obj`` batches, the model is queried
    over the ``unseen_xyz`` points built by ``prepare_data(..., is_viz=True)``.
    Queries are sent in chunks of 2000, with the per-query colour and occupancy
    inputs zeroed. The occupancy logits and decoded colours are rendered by
    ``generate_html`` into
    ``{args.job_dir}/viz/{split}_ep{epoch}_rank{rank}_i{sample_idx}.html``.
    That file is opened in append mode, so repeated runs add to it.
    """
    epoch_start_time = time.time()
    model.eval()
    # Create the directory in-process rather than shelling out to `mkdir`:
    # no error once it exists, and no shell parsing of job_dir (spaces, metacharacters).
    viz_dir = os.path.join(args.job_dir, 'viz')
    os.makedirs(viz_dir, exist_ok=True)

    print('Visualization data_loader length:', len(data_loader))
    dataset = data_loader.dataset
    # Unwrap a DDP/DataParallel-style wrapper (if present) to reach clear_cache().
    base_model = model.module if hasattr(model, "module") else model
    rank = misc.get_rank()
    # don't forward all at once to avoid oom
    max_n_queries_fwd = 2000

    for sample_idx, samples in enumerate(data_loader):
        if sample_idx >= args.max_n_viz_obj:
            break
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
            samples, device, is_train=False, args=args, is_viz=True)

        pred_occupy = []
        pred_colors = []
        # Presumably drops encoder state cached for the previous sample (see cache_enc below).
        base_model.clear_cache()

        total_n_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_queries_fwd))
        for p_idx in range(total_n_passes):
            p_start = p_idx * max_n_queries_fwd
            p_end = p_start + max_n_queries_fwd
            cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
            # Ground-truth colour/occupancy are not fed to the model here. Fresh
            # zero tensors avoid mutating unseen_rgb / labels in place.
            cur_unseen_rgb = torch.zeros_like(unseen_rgb[:, p_start:p_end])
            cur_labels = torch.zeros_like(labels[:, p_start:p_end])

            with torch.no_grad():
                _, pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=cur_unseen_xyz,
                    unseen_rgb=cur_unseen_rgb,
                    unseen_occupy=cur_labels,
                    cache_enc=args.run_viz,
                    valid_seen_xyz=valid_seen_xyz,
                )

            # Channel 0 is the occupancy logit; the remaining channels encode colour.
            cur_occupy_out = pred[..., 0]

            if args.regress_color:
                cur_color_out = pred[..., 1:].reshape((-1, 3))
            else:
                # 256 bins per RGB channel: take the argmax bin and scale it to [0, 1].
                cur_color_out = pred[..., 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
            pred_occupy.append(cur_occupy_out)
            pred_colors.append(cur_color_out)

        prefix = os.path.join(viz_dir, f'{dataset.dataset_split}_ep{epoch}_rank{rank}_i{sample_idx}')

        img = (seen_images[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)

        gt_xyz = samples[1][0].to(device).reshape(-1, 3)
        gt_rgb = samples[1][1].to(device).reshape(-1, 3)
        mesh_xyz = samples[2].to(device).reshape(-1, 3) if args.use_hypersim else None

        with open(prefix + '.html', 'a', encoding='utf-8') as f:
            generate_html(
                img,
                seen_xyz, seen_images,
                torch.cat(pred_occupy, dim=1),
                torch.cat(pred_colors, dim=0),
                unseen_xyz,
                f,
                gt_xyz=gt_xyz,
                gt_rgb=gt_rgb,
                mesh_xyz=mesh_xyz,
            )
    print("Visualization epoch time:", time.time() - epoch_start_time)


def get_f1(precision, recall):
    """Harmonic mean of precision and recall; ``0.0`` when both are zero."""
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2.0 * precision * recall / denom


def _build_viz_clouds(seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
                      gt_xyz, gt_rgb, mesh_xyz, score_thresholds, max_ref_points=10000):
    """Collect the point clouds shown by ``generate_plot`` / ``generate_html``.

    Returns:
        ``(clouds, last_pred)``. ``clouds`` is ``{"MCC Output": {name: Pointclouds}}``
        holding, in insertion order:
          - ``"seen"``: the valid seen points;
          - ``"GT points"`` / ``"GT mesh"``: random subsamples of at most
            ``max_ref_points`` points each;
          - ``"pred_{t}"``: one cloud per threshold ``t`` in ``score_thresholds``
            whose sigmoid occupancy exceeds ``t``.
        Thresholds that keep no points are omitted. ``last_pred`` is the
        prediction cloud built for the last threshold that kept any points,
        or ``None``.
    """
    clouds = {}
    # -100 is the sentinel coordinate used for invalid points.
    if seen_xyz is not None:
        seen_xyz = seen_xyz.reshape((-1, 3)).cpu()
        # Colours are resized to 112x112, assumed to match the seen xyz map's resolution -- TODO confirm.
        seen_rgb = torch.nn.functional.interpolate(seen_rgb, (112, 112)).permute(0, 2, 3, 1).reshape((-1, 3)).cpu()
        good_seen = seen_xyz[:, 0] != -100
        clouds["seen"] = Pointclouds(
            points=seen_xyz[good_seen][None],
            features=seen_rgb[good_seen][None],
        )

    # GT points: cap the sample size so clouds smaller than max_ref_points
    # don't make random.sample raise ValueError.
    if gt_xyz is not None:
        subset_gt = random.sample(range(gt_xyz.shape[0]), min(max_ref_points, gt_xyz.shape[0]))
        clouds["GT points"] = Pointclouds(
            points=gt_xyz[subset_gt][None],
            features=gt_rgb[subset_gt][None],
        )

    # GT meshes
    if mesh_xyz is not None:
        subset_mesh = random.sample(range(mesh_xyz.shape[0]), min(max_ref_points, mesh_xyz.shape[0]))
        clouds["GT mesh"] = Pointclouds(
            points=mesh_xyz[subset_mesh][None],
        )

    last_pred = None
    occ_prob = torch.sigmoid(pred_occ).cpu()
    for t in score_thresholds:
        pos = occ_prob > t

        points = unseen_xyz[pos].reshape((-1, 3))
        # NOTE(review): pred_rgb[None] prepends a size-1 axis, so indexing it with
        # ``pos`` appears to line up only when the batch size is 1 -- confirm.
        features = pred_rgb[None][pos].reshape((-1, 3))
        good_points = points[:, 0] != -100

        if good_points.sum() == 0:
            continue

        last_pred = Pointclouds(
            points=points[good_points][None].cpu(),
            features=features[good_points][None].cpu(),
        )
        clouds[f"pred_{t}"] = last_pred
    return {"MCC Output": clouds}, last_pred


def _plot_viz_clouds(clouds, pointcloud_marker_size):
    """Render ``clouds`` with pytorch3d's plotly ``plot_scene`` as a 1000x1000 figure."""
    fig = plot_scene(clouds, pointcloud_marker_size=pointcloud_marker_size, pointcloud_max_points=20000 * 2)
    fig.update_layout(height=1000, width=1000)
    return fig


def generate_plot(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
        gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
        pointcloud_marker_size=2,
    ):
    """Build the plotly figure for one prediction and export a prediction cloud.

    ``img`` is accepted for signature parity with ``generate_html`` and is not
    used. The prediction cloud for the last threshold in ``score_thresholds``
    that kept any points is written to ``output_pointcloud.ply`` in the
    current working directory.

    Returns:
        The plotly figure, or ``None`` if plotting failed (the error is printed).
    """
    clouds, last_pred = _build_viz_clouds(
        seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
        gt_xyz, gt_rgb, mesh_xyz, score_thresholds,
    )
    if last_pred is not None:
        IO().save_pointcloud(last_pred, "output_pointcloud.ply")

    # plot_scene renders with plotly, so no matplotlib figure is opened here;
    # one opened per call would never be closed on the success path.
    try:
        return _plot_viz_clouds(clouds, pointcloud_marker_size)
    except Exception as e:
        # Best-effort visualisation: report and carry on.
        print('writing failed', e)
        return None


def generate_html(img, seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz, f,
        gt_xyz=None, gt_rgb=None, mesh_xyz=None, score_thresholds=(0.1, 0.3, 0.5, 0.7, 0.9),
        pointcloud_marker_size=2,
    ):
    """Write an HTML fragment (input image plus interactive point clouds) to ``f``.

    Args:
        img: HxWx3 uint8 image embedded as a JPEG data URI, or ``None`` to skip it.
        f: Open text file the fragment is written to.

    Returns:
        ``(fig, plt)`` on success, or ``None`` if plotting failed (the error is printed).
    """
    if img is not None:
        fig = plt.figure()
        try:
            plt.imshow(img)
            tmpfile = BytesIO()
            fig.savefig(tmpfile, format='jpg')
        finally:
            # Close even when rendering fails so figures don't accumulate across calls.
            plt.close(fig)
        encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8')

        # savefig wrote JPEG bytes, so the data URI must declare image/jpeg.
        html = '<img src=\'data:image/jpeg;base64,{}\'>'.format(encoded)
        f.write(html)

    clouds, _ = _build_viz_clouds(
        seen_xyz, seen_rgb, pred_occ, pred_rgb, unseen_xyz,
        gt_xyz, gt_rgb, mesh_xyz, score_thresholds,
    )

    try:
        fig = _plot_viz_clouds(clouds, pointcloud_marker_size)
        # "cdn" references plotly.js from the CDN. Any unrecognised truthy value
        # (such as the typo "cnd") makes plotly inline the whole plotly.js
        # bundle into every fragment appended to the file.
        html_string = fig.to_html(full_html=False, include_plotlyjs="cdn")
        f.write(html_string)
        return fig, plt
    except Exception as e:
        # Best-effort visualisation: report and carry on.
        print('writing failed', e)
        return None


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args=None):
    """Run one training epoch with gradient accumulation and mixed precision.

    The learning rate is set per iteration via ``lr_sched.adjust_learning_rate``
    at the start of every accumulation window of ``args.accum_iter`` steps. Each
    batch is prepared by ``prepare_data(..., is_train=True)`` (which applies the
    random augmentations) and forwarded under ``torch.cuda.amp.autocast``. The
    loss divided by ``accum_iter`` is handed to ``loss_scaler``. Gradients are
    zeroed after the last step of each window.

    Args:
        model: Network returning ``(loss, pred)``.
        data_loader: Training batches.
        optimizer: Optimizer whose first param group's ``lr`` is logged.
        device: Device the batch tensors are moved to.
        epoch: Current epoch index (used for the fractional-epoch LR schedule).
        loss_scaler: Callable taking ``(loss, optimizer, parameters=..., clip_grad=...,
            update_grad=..., verbose=...)``. Presumably performs backward and steps
            the optimizer only when ``update_grad`` is true -- confirm in util.misc.
        args: Config namespace. Reads ``accum_iter``, ``clip_grad`` and ``debug``,
            plus whatever ``prepare_data`` and ``lr_sched`` read.

    Returns:
        Dict mapping each metric name (``loss``, ``lr``) to its global average
        after synchronising across processes.
    """
    epoch_start_time = time.time()
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    print('Training data_loader length:', len(data_loader))
    for data_iter_step, samples in enumerate(data_loader):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(samples, device, is_train=True, args=args)

        with torch.cuda.amp.autocast():
            loss, _ = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=unseen_xyz,
                unseen_rgb=unseen_rgb,
                unseen_occupy=labels,
                valid_seen_xyz=valid_seen_xyz,
            )

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Warning: Loss is {}".format(loss_value))
            # NOTE(review): multiplying by 0.0 does not clear NaN/inf (NaN*0 and
            # inf*0 are both NaN). Backward still runs on this step -- presumably
            # so every rank participates, relying on loss_scaler to skip a
            # non-finite update. Confirm.
            loss *= 0.0
            # Placeholder so the logged average stays finite.
            loss_value = 100.0

        # Scale so the accumulated gradient over accum_iter steps is an average.
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    clip_grad=args.clip_grad,
                    update_grad=(data_iter_step + 1) % accum_iter == 0,
                    verbose=(data_iter_step % 100) == 0)

        # End of an accumulation window: start the next one with fresh gradients.
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # One-off GPU / host memory snapshot early in the epoch (diagnostics only).
        if data_iter_step == 30:
            os.system('nvidia-smi')
            os.system('free -g')
        if args.debug and data_iter_step == 5:
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Training epoch time:", time.time() - epoch_start_time)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def eval_one_epoch(
        model: torch.nn.Module,
        data_loader: Iterable,
        device: torch.device,
        args=None
    ):
    """Evaluate the model over ``data_loader`` and return epoch-averaged metrics.

    For each batch, occupancy is predicted for every query in ``unseen_xyz``.
    Queries are forwarded in chunks of 5000 to bound memory. Queries whose
    sigmoid score exceeds ``args.eval_score_threshold`` form the predicted
    cloud. That cloud is scored with ``evaluate_points`` at
    ``args.eval_dist_threshold`` against the GT points and, when
    ``args.use_hypersim`` is set, against the mesh samples too. The logged loss
    is the unweighted mean of the per-chunk losses.

    Returns:
        Dict mapping each metric name to its global average across processes.
    """
    start_time = time.time()
    model.train(False)

    logger = misc.MetricLogger(delimiter="  ")

    print('Eval len(data_loader):', len(data_loader))

    # don't forward all at once to avoid oom
    chunk = 5000

    for step, samples in enumerate(data_loader):
        seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
            samples, device, is_train=False, args=args)

        losses, preds = [], []
        for lo in range(0, unseen_xyz.shape[1], chunk):
            hi = lo + chunk
            with torch.no_grad():
                chunk_loss, chunk_pred = model(
                    seen_images=seen_images,
                    seen_xyz=seen_xyz,
                    unseen_xyz=unseen_xyz[:, lo:hi],
                    unseen_rgb=unseen_rgb[:, lo:hi],
                    unseen_occupy=labels[:, lo:hi],
                    valid_seen_xyz=valid_seen_xyz,
                )
            losses.append(chunk_loss)
            preds.append(chunk_pred)

        loss = sum(losses) / len(losses)
        pred = torch.cat(preds, dim=1)
        batch = pred.shape[0]

        gt_xyz = samples[1][0].to(device).reshape((batch, -1, 3))
        mesh_xyz = samples[2].to(device).reshape((batch, -1, 3)) if args.use_hypersim else None

        s_thres = args.eval_score_threshold
        d_thres = args.eval_dist_threshold

        for b in range(batch):
            occupied = torch.sigmoid(pred[b, :, 0]) > s_thres
            predicted_xyz = unseen_xyz[b, occupied]

            pr, rc, f1 = evaluate_points(predicted_xyz, gt_xyz[b], d_thres)
            metrics = {
                f'd{d_thres}_s{s_thres}_point_pr': pr,
                f'd{d_thres}_s{s_thres}_point_rc': rc,
                f'd{d_thres}_s{s_thres}_point_f1': f1,
            }

            if args.use_hypersim:
                pr, rc, f1 = evaluate_points(predicted_xyz, mesh_xyz[b], d_thres)
                metrics[f'd{d_thres}_s{s_thres}_mesh_pr'] = pr
                metrics[f'd{d_thres}_s{s_thres}_mesh_rc'] = rc
                metrics[f'd{d_thres}_s{s_thres}_mesh_f1'] = f1

            logger.update(**metrics)

        loss_value = loss.item()

        torch.cuda.synchronize()
        logger.update(loss=loss_value)

        if args.debug and step == 5:
            break

    logger.synchronize_between_processes()
    print("Validation averaged stats:", logger)
    print("Val epoch time:", time.time() - start_time)
    return {name: meter.global_avg for name, meter in logger.meters.items()}


def sample_uniform_semisphere(B, N, semisphere_size, device, max_tries=100):
    """Draw ``B * N`` points uniformly from the upper half-ball of radius ``semisphere_size``.

    Uses rejection sampling. Each round draws ``3 * B * N`` candidates uniformly
    from the box ``[-r, r] x [-r, r] x [0, r]`` and keeps those strictly inside
    the ball. The half-ball fills about 52% of that box, so one round usually
    suffices. Otherwise a message is printed and the round is retried, up to
    ``max_tries`` rounds.

    Args:
        B: Batch size.
        N: Points per batch entry.
        semisphere_size: Radius ``r``.
        device: Device for the returned tensor.
        max_tries: Maximum number of sampling rounds.

    Returns:
        Tensor of shape (B, N, 3) with z >= 0.

    Raises:
        RuntimeError: If none of the ``max_tries`` rounds yields enough points.
    """
    for _ in range(max_tries):
        points = torch.empty(B * N * 3, 3, device=device).uniform_(-semisphere_size, semisphere_size)
        # Fold the lower half onto the upper half: |U(-r, r)| is uniform on [0, r].
        points[..., 2] = points[..., 2].abs()
        inside = (points ** 2.0).sum(axis=-1) ** 0.5 < semisphere_size
        if inside.sum() >= B * N:
            return points[inside][:B * N].reshape((B, N, 3))
        print('resampling sphere')
    raise RuntimeError(
        f'sample_uniform_semisphere: failed to draw {B * N} points inside radius '
        f'{semisphere_size} after {max_tries} tries'
    )


def get_grid_semisphere(B, granularity, semisphere_size, device):
    """Regular query grid over the closed upper half-ball of radius ``semisphere_size``.

    With ``n = 2 * int(semisphere_size / granularity) + 1``, X and Y take the
    values ``(i - n // 2) * granularity`` for ``i`` in ``0..n-1``. Z takes
    ``k * granularity`` for ``k`` in ``0..n // 2``. Points farther than
    ``semisphere_size`` from the origin are dropped.

    Returns:
        Tensor of shape (B, M, 3), identical for every batch entry.
    """
    side = int(semisphere_size / granularity) * 2 + 1
    depth = side // 2 + 1
    dtype = torch.get_default_dtype()
    xy_ticks = torch.arange(side, device=device, dtype=dtype)
    z_ticks = torch.arange(depth, device=device, dtype=dtype)
    shape = (side, side, depth)
    grid = torch.stack(
        (
            xy_ticks.view(side, 1, 1).expand(shape),
            xy_ticks.view(1, side, 1).expand(shape),
            z_ticks.view(1, 1, depth).expand(shape),
        ),
        dim=-1,
    )
    # Centre X/Y on the origin; Z stays non-negative.
    grid[..., :2] -= (side // 2.0)
    grid *= granularity
    radius = (grid ** 2.0).sum(axis=-1) ** 0.5
    kept = grid[radius <= semisphere_size]
    return kept[None].repeat(B, 1, 1)


def get_min_dist(a, b, slice_size=1000):
    """Nearest-neighbour distance and index from each query in ``a`` to the points in ``b``.

    Callers in this file pass ``a`` as (B, n_queries, 1, 3) and ``b`` as
    (B, 1, n_ref, 3). Queries are processed ``slice_size`` at a time, so only a
    (B, slice_size, n_ref) distance block exists at once.

    Returns:
        ``(min_dist, min_idx)``, each of shape (B, n_queries).
    """
    mins, idxs = [], []
    for lo in range(0, a.shape[1], slice_size):
        chunk = a[:, lo:lo + slice_size]
        # (B, chunk, n_ref) Euclidean distances.
        pairwise = ((chunk - b) ** 2.0).sum(dim=-1) ** 0.5
        best, best_idx = pairwise.min(dim=2)
        mins.append(best)
        idxs.append(best_idx)
    return torch.cat(mins, dim=1), torch.cat(idxs, dim=1)


def _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold):
    """Label each query occupied when its nearest GT point is closer than ``dist_threshold``.

    Occupied queries take the colour of that nearest GT point; all others get zeros.

    Returns:
        ``(unseen_xyz, unseen_rgb, labels)`` with ``labels`` as a float tensor.
    """
    dist, nearest = get_min_dist(unseen_xyz[:, :, None], gt_xyz[:, None])
    occupied = dist < dist_threshold
    unseen_rgb = torch.zeros_like(unseen_xyz)
    nearest_rgb = torch.gather(gt_rgb, 1, nearest.unsqueeze(-1).repeat(1, 1, 3))
    unseen_rgb[occupied] = nearest_rgb[occupied]
    return unseen_xyz, unseen_rgb, occupied.float()


def construct_uniform_semisphere(gt_xyz, gt_rgb, semisphere_size, n_queries, dist_threshold, is_train, granularity):
    """Build labelled queries in the upper half-ball.

    Training draws ``n_queries`` random points per batch entry. Otherwise a
    regular grid with spacing ``granularity`` is used.
    """
    batch, device = gt_xyz.shape[0], gt_xyz.device
    if is_train:
        unseen_xyz = sample_uniform_semisphere(batch, n_queries, semisphere_size, device)
    else:
        unseen_xyz = get_grid_semisphere(batch, granularity, semisphere_size, device)
    return _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold)


def construct_uniform_grid(gt_xyz, gt_rgb, co3d_world_size, n_queries, dist_threshold, is_train, granularity):
    """Build labelled queries in the cube ``[-co3d_world_size, co3d_world_size]^3``.

    Training draws ``n_queries`` uniform random points per batch entry.
    Otherwise the dense grid from ``get_grid`` is used.
    """
    batch, device = gt_xyz.shape[0], gt_xyz.device
    if is_train:
        unseen_xyz = torch.empty((batch, n_queries, 3), device=device).uniform_(-co3d_world_size, co3d_world_size)
    else:
        unseen_xyz = get_grid(batch, device, co3d_world_size, granularity)
    return _label_queries(unseen_xyz, gt_xyz, gt_rgb, dist_threshold)


def prepare_data(samples, device, is_train, args, is_viz=False):
    """Move a batch to ``device`` and build the model inputs.

    ``samples[0]`` holds the seen ``(xyz, rgb)`` pair and ``samples[1]`` the
    ground-truth ``(xyz, rgb)`` pair. Seen points with a non-finite coordinate
    are flagged as invalid in ``valid_seen_xyz`` and overwritten with -100.
    Query points, colours and occupancy labels come from
    ``construct_uniform_semisphere`` when ``args.use_hypersim`` is set, and from
    ``construct_uniform_grid`` otherwise. When training, the points are
    augmented by ``aug_xyz``. With probability 0.5 the batch is also mirrored:
    the X coordinates are negated, and the seen xyz map, validity mask and image
    are flipped along their width axis.

    Returns:
        ``(seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb)``
    """
    seen_xyz = samples[0][0].to(device)
    seen_rgb = samples[0][1].to(device)
    valid_seen_xyz = torch.isfinite(seen_xyz.sum(axis=-1))
    seen_xyz[~valid_seen_xyz] = -100
    batch = seen_xyz.shape[0]

    gt_xyz = samples[1][0].to(device).reshape(batch, -1, 3)
    gt_rgb = samples[1][1].to(device).reshape(batch, -1, 3)

    if args.use_hypersim:
        sampler, extent = construct_uniform_semisphere, args.semisphere_size
    else:
        sampler, extent = construct_uniform_grid, args.co3d_world_size
    granularity = args.viz_granularity if is_viz else args.eval_granularity
    unseen_xyz, unseen_rgb, labels = sampler(
        gt_xyz, gt_rgb, extent, args.n_queries, args.train_dist_threshold, is_train, granularity,
    )

    if not is_train:
        return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb

    seen_xyz, unseen_xyz = aug_xyz(seen_xyz, unseen_xyz, args, is_train=is_train)

    # Random mirror along X, keeping the image-aligned tensors consistent.
    if random.random() < 0.5:
        seen_xyz[..., 0] *= -1
        unseen_xyz[..., 0] *= -1
        seen_xyz = torch.flip(seen_xyz, [2])
        valid_seen_xyz = torch.flip(valid_seen_xyz, [2])
        seen_rgb = torch.flip(seen_rgb, [3])

    return seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_rgb