import math

import jax
from flax.core import unfreeze, freeze
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    """
    Sine activation with a learnable per-feature phase, multiplied by an
    exponential damping term that depends on `t`, `k` and `norm`.
    """
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        """
        Computes `sin(w0_scale * x + phase) * exp(-(w0_scale * norm)**2 * k * t)`.

        args:
            x: Pre-activations; one phase value is learned per entry of the last axis
            t: Multiplier in the damping exponent
            norm: Multiplied by `w0_scale` and squared in the damping exponent
            k: Multiplier in the damping exponent
        """
        # one learnable phase offset per feature (size of the last axis of `x`)
        feature_shape = x.shape[-1:]
        phase = self.param('phase', nn.initializers.uniform(.5), feature_shape)

        oscillation = jnp.sin(self.w0_scale * x + phase)

        # damping grows with the squared, scaled norm as well as with `k` and `t`
        scaled_norm = self.w0_scale * norm
        damping = jnp.exp(-(scaled_norm ** 2) * k * t)

        return oscillation * damping


class TheraField(nn.Module):
    """
    Two-layer field: a projection of the input coordinates onto externally
    supplied `components`, a `Thermal` activation, and a bias-free dense
    read-out to `dim_out` channels.
    """
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        """
        args:
            x: Coordinates, right-multiplied by `components`
            t: Forwarded to the thermal activation
            k: Forwarded to the thermal activation
            components: Projection matrix of the "first layer"; it is passed in
                rather than created as a parameter of this module
        """
        # "first layer": project the coordinates onto the shared components
        projected = x @ components

        # the thermal activation also receives the norm of each component,
        # taken over the second-to-last axis of `components`
        component_norms = jnp.linalg.norm(components, axis=-2)
        hidden = Thermal(self.w0)(projected, t, component_norms, k)

        # "second layer": bias-free linear map from hidden to output space, with
        # kernel init bounds of +/- sqrt(c / dim_hidden) / w0 handed to `uniform_between`
        bound = math.sqrt(self.c / self.dim_hidden) / self.w0
        readout = nn.Dense(
            self.dim_out,
            kernel_init=uniform_between(-bound, bound),
            use_bias=False,
        )
        return readout(hidden)


class Thera:
    """
    Thera model: a hypernetwork that parameterizes a grid of `TheraField`s.

    A single `TheraField` is instantiated once and its `apply` method is reused
    for every grid cell. The per-cell field parameters are produced by the
    hypernetwork, whereas `k` and `components` are stored once in the model
    parameters and are not mapped over grid cells.
    """

    def __init__(
            self,
            hidden_dim: int,
            out_dim: int,
            backbone: nn.Module,
            tail: nn.Module,
            k_init: float | None = None,
            components_init_scale: float | None = None
    ):
        """
        args:
            hidden_dim: Number of hidden features (columns of `components`)
            out_dim: Number of output channels of the field
            backbone: Backbone module handed to the hypernetwork
            tail: Tail module handed to the hypernetwork
            k_init: Initial value stored under `params['params']['k']` by `init`
            components_init_scale: Argument passed to `linear_up` when `init`
                creates `params['params']['components']`
        """
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the field;
        # the key does not matter as field params are only used for size inference
        sample_params = self.field.init(
            jax.random.PRNGKey(0),
            jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        """
        Initializes the hypernetwork parameters and adds the shared `k` and
        `components` entries under `params['params']`.

        NOTE(review): `k_init` and `components_init_scale` default to None in the
        constructor; confirm that callers of `init` always provide them.

        args:
            key: PRNG key; split into one key for the hypernetwork and one for
                the components
            sample_source: Sample input; sample coordinates are created with the
                same leading shape and a trailing axis of size 2
        """
        hypernet_key, components_key = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(hypernet_key, sample_source, sample_coords))

        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(components_key, (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        Extra keyword arguments are forwarded to `hypernet.apply`.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Model parameters (PyTree) as returned by `init`; must hold
                `k` and `components` under `params['params']`
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
            return_jac: If True, additionally return the Jacobian of the field
                output with respect to the local (relative) coordinates,
                evaluated with `t` replaced by zeros
        returns:
            The field output, or the tuple `(output, jacobian)` if `return_jac`.
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems: offset of each query coordinate from
        # its interpolated source grid coordinate, scaled by the grid size per axis
        grid_shape = encoding.shape[-3:-1]
        source_grid = jnp.asarray(make_grid(grid_shape))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = coords - interp_coords
        # one broadcast multiply scales channel 0 by H and channel 1 by W
        rel_coords = rel_coords * jnp.asarray(grid_shape, dtype=rel_coords.dtype)

        # k and components are shared by all fields
        k = params['params']['k']
        components = params['params']['components']

        # three maps over params, coords; one over t; dont map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, k, components)

        if return_jac:
            # argnums=1 differentiates w.r.t. the coordinates passed to the field;
            # t is zeroed for the Jacobian
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), k, components)
            return out, jac

        return out

    def apply(
        self,
        params: PyTree,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward pass through the Thera model: encodes `source`, then
        decodes at `coords` and `t`. Extra keyword arguments are forwarded to
        the encoder only. Returns `(output, jacobian)` if `return_jac` is True,
        otherwise just the output.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        return self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)


def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.

    args:
        out_dim: Number of output channels of the model
        backbone: Either 'edsr-baseline' or 'rdn'
        size: Model size; 'air' selects 32 hidden features, any other value 512.
            The value is also passed to `build_tail`
        k_init: Forwarded to `Thera`
        components_init_scale: Forwarded to `Thera`
    raises:
        NotImplementedError: If `backbone` is not one of the supported names.
    """
    # only the 'air' variant uses the narrow field
    if size == 'air':
        n_hidden = 32
    else:
        n_hidden = 512

    if backbone == 'edsr-baseline':
        encoder = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        encoder = RDN()
    else:
        raise NotImplementedError(backbone)

    return Thera(
        hidden_dim=n_hidden,
        out_dim=out_dim,
        backbone=encoder,
        tail=build_tail(size),
        k_init=k_init,
        components_init_scale=components_init_scale,
    )