import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models.base import BaseModel
from models.utils import chunk_batch
from systems.utils import update_module_step
from nerfacc import (
    ContractionType,
    OccupancyGrid,
    ray_marching,
    render_weight_from_density,
    render_weight_from_alpha,
    accumulate_along_rays,
)
from nerfacc.intersection import ray_aabb_intersect

import pdb


class VarianceNetwork(nn.Module):
    def __init__(self, config):
        super(VarianceNetwork, self).__init__()
        self.config = config
        self.init_val = self.config.init_val
        self.register_parameter(
            "variance", nn.Parameter(torch.tensor(self.config.init_val))
        )
        self.modulate = self.config.get("modulate", False)
        if self.modulate:
            self.mod_start_steps = self.config.mod_start_steps
            self.reach_max_steps = self.config.reach_max_steps
            self.max_inv_s = self.config.max_inv_s

    @property
    def inv_s(self):
        val = torch.exp(self.variance * 10.0)
        if self.modulate and self.do_mod:
            val = val.clamp_max(self.mod_val)
        return val

    def forward(self, x):
        return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s

    def update_step(self, epoch, global_step):
        if self.modulate:
            self.do_mod = global_step > self.mod_start_steps
            if not self.do_mod:
                self.prev_inv_s = self.inv_s.item()
            else:
                self.mod_val = min(
                    (global_step / self.reach_max_steps)
                    * (self.max_inv_s - self.prev_inv_s)
                    + self.prev_inv_s,
                    self.max_inv_s,
                )


@models.register("neus")
class NeuSModel(BaseModel):
    def setup(self):
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.geometry.contraction_type = ContractionType.AABB

        if self.config.learned_background:
            self.geometry_bg = models.make(
                self.config.geometry_bg.name, self.config.geometry_bg
            )
            self.texture_bg = models.make(
                self.config.texture_bg.name, self.config.texture_bg
            )
            self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE
            self.near_plane_bg, self.far_plane_bg = 0.1, 1e3
            self.cone_angle_bg = (
                10
                ** (math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg)
                - 1.0
            )
            self.render_step_size_bg = 0.01

        self.variance = VarianceNetwork(self.config.variance)
        self.register_buffer(
            "scene_aabb",
            torch.as_tensor(
                [
                    -self.config.radius,
                    -self.config.radius,
                    -self.config.radius,
                    self.config.radius,
                    self.config.radius,
                    self.config.radius,
                ],
                dtype=torch.float32,
            ),
        )
        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=128,
                contraction_type=ContractionType.AABB,
            )
            if self.config.learned_background:
                self.occupancy_grid_bg = OccupancyGrid(
                    roi_aabb=self.scene_aabb,
                    resolution=256,
                    contraction_type=ContractionType.UN_BOUNDED_SPHERE,
                )
        self.randomized = self.config.randomized
        self.background_color = None
        self.render_step_size = (
            1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
        )

    def update_step(self, epoch, global_step):
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)
        if self.config.learned_background:
            update_module_step(self.geometry_bg, epoch, global_step)
            update_module_step(self.texture_bg, epoch, global_step)
        update_module_step(self.variance, epoch, global_step)

        cos_anneal_end = self.config.get("cos_anneal_end", 0)
        self.cos_anneal_ratio = (
            1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end)
        )

        def occ_eval_fn(x):
            sdf = self.geometry(x, with_grad=False, with_feature=False)
            inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
            inv_s = inv_s.expand(sdf.shape[0], 1)
            estimated_next_sdf = sdf[..., None] - self.render_step_size * 0.5
            estimated_prev_sdf = sdf[..., None] + self.render_step_size * 0.5
            prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
            next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
            p = prev_cdf - next_cdf
            c = prev_cdf
            alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0)
            return alpha

        def occ_eval_fn_bg(x):
            density, _ = self.geometry_bg(x)
            # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size_bg) based on taylor series
            return density[..., None] * self.render_step_size_bg

        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(
                step=global_step,
                occ_eval_fn=occ_eval_fn,
                occ_thre=self.config.get("grid_prune_occ_thre", 0.01),
            )
            if self.config.learned_background:
                self.occupancy_grid_bg.every_n_step(
                    step=global_step,
                    occ_eval_fn=occ_eval_fn_bg,
                    occ_thre=self.config.get("grid_prune_occ_thre_bg", 0.01),
                )

    def isosurface(self):
        mesh = self.geometry.isosurface()
        return mesh

    def get_alpha(self, sdf, normal, dirs, dists):
        inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(
            1e-6, 1e6
        )  # Single parameter
        inv_s = inv_s.expand(sdf.shape[0], 1)

        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
        # the cos value "not dead" at the beginning training iterations, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at section points
        estimated_next_sdf = sdf[..., None] + iter_cos * dists.reshape(-1, 1) * 0.5
        estimated_prev_sdf = sdf[..., None] - iter_cos * dists.reshape(-1, 1) * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0)
        return alpha

    def forward_bg_(self, rays):
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
            density, _ = self.geometry_bg(positions)
            return density[..., None]

        _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb)
        # if the ray intersects with the bounding box, start from the farther intersection point
        # otherwise start from self.far_plane_bg
        # note that in nerfacc t_max is set to 1e10 if there is no intersection
        near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max)
        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=None,
                grid=self.occupancy_grid_bg if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=near_plane,
                far_plane=self.far_plane_bg,
                render_step_size=self.render_step_size_bg,
                stratified=self.randomized,
                cone_angle=self.cone_angle_bg,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry_bg(positions)
        rgb = self.texture_bg(feature, t_dirs)

        weights = render_weight_from_density(
            t_starts, t_ends, density[..., None], ray_indices=ray_indices, n_rays=n_rays
        )
        opacity = accumulate_along_rays(
            weights, ray_indices, values=None, n_rays=n_rays
        )
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(
            weights, ray_indices, values=rgb, n_rays=n_rays
        )
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            "comp_rgb": comp_rgb,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        if self.training:
            out.update(
                {
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": intervals.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )

        return out

    def forward_(self, rays):
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                alpha_fn=None,
                near_plane=None,
                far_plane=None,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=0.0,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        dists = t_ends - t_starts

        if self.config.geometry.grad_type == "finite_difference":
            sdf, sdf_grad, feature, sdf_laplace = self.geometry(
                positions, with_grad=True, with_feature=True, with_laplace=True
            )
        else:
            sdf, sdf_grad, feature = self.geometry(
                positions, with_grad=True, with_feature=True
            )

        normal = F.normalize(sdf_grad, p=2, dim=-1)
        alpha = self.get_alpha(sdf, normal, t_dirs, dists)[..., None]
        rgb = self.texture(feature, t_dirs, normal)

        weights = render_weight_from_alpha(
            alpha, ray_indices=ray_indices, n_rays=n_rays
        )
        opacity = accumulate_along_rays(
            weights, ray_indices, values=None, n_rays=n_rays
        )
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(
            weights, ray_indices, values=rgb, n_rays=n_rays
        )

        comp_normal = accumulate_along_rays(
            weights, ray_indices, values=normal, n_rays=n_rays
        )
        comp_normal = F.normalize(comp_normal, p=2, dim=-1)

        pts_random = (
            torch.rand([1024 * 2, 3]).to(sdf.dtype).to(sdf.device) * 2 - 1
        )  # normalized to (-1, 1)

        if self.config.geometry.grad_type == "finite_difference":
            random_sdf, random_sdf_grad, _ = self.geometry(
                pts_random, with_grad=True, with_feature=False, with_laplace=True
            )
            _, normal_perturb, _ = self.geometry(
                pts_random + torch.randn_like(pts_random) * 1e-2,
                with_grad=True,
                with_feature=False,
                with_laplace=True,
            )
        else:
            random_sdf, random_sdf_grad = self.geometry(
                pts_random, with_grad=True, with_feature=False
            )
            _, normal_perturb = self.geometry(
                positions + torch.randn_like(positions) * 1e-2,
                with_grad=True,
                with_feature=False,
            )

        # pdb.set_trace()
        out = {
            "comp_rgb": comp_rgb,
            "comp_normal": comp_normal,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        if self.training:
            out.update(
                {
                    "sdf_samples": sdf,
                    "sdf_grad_samples": sdf_grad,
                    "random_sdf": random_sdf,
                    "random_sdf_grad": random_sdf_grad,
                    "normal_perturb": normal_perturb,
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": dists.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )
            if self.config.geometry.grad_type == "finite_difference":
                out.update({"sdf_laplace_samples": sdf_laplace})

        if self.config.learned_background:
            out_bg = self.forward_bg_(rays)
        else:
            out_bg = {
                "comp_rgb": self.background_color[None, :].expand(*comp_rgb.shape),
                "num_samples": torch.zeros_like(out["num_samples"]),
                "rays_valid": torch.zeros_like(out["rays_valid"]),
            }

        out_full = {
            "comp_rgb": out["comp_rgb"] + out_bg["comp_rgb"] * (1.0 - out["opacity"]),
            "num_samples": out["num_samples"] + out_bg["num_samples"],
            "rays_valid": out["rays_valid"] | out_bg["rays_valid"],
        }

        return {
            **out,
            **{k + "_bg": v for k, v in out_bg.items()},
            **{k + "_full": v for k, v in out_full.items()},
        }

    def forward(self, rays):
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {**out, "inv_s": self.variance.inv_s}

    def train(self, mode=True):
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, sdf_grad, feature = chunk_batch(
                self.geometry,
                export_config.chunk_size,
                False,
                mesh["v_pos"].to(self.rank),
                with_grad=True,
                with_feature=True,
            )
            normal = F.normalize(sdf_grad, p=2, dim=-1)
            rgb = self.texture(
                feature, -normal, normal
            )  # set the viewing directions to the normal to get "albedo"
            mesh["v_rgb"] = rgb.cpu()
        return mesh