from __future__ import annotations

import base64
import gzip
import json
from dataclasses import dataclass, fields
from io import BytesIO
from pathlib import Path
from urllib.parse import parse_qsl

import altair as alt
import ipywidgets as widgets
import numpy as np
import polars as pl
import solara
import solara.lab
from cmap import Colormap
from ipymolstar.widget import PDBeMolstar
from pydantic import BaseModel

from make_link import decode_data

# Element-wise np.base_repr: converts each integer in an array to its string form in a given base.
base_v = np.vectorize(pyfunc=np.base_repr)
# Fractional padding for the y-axis domain when autoscale_y is off.
PAD_SIZE: float = 0.05


def norm(x, vmin, vmax):
    """Linearly rescale *x* so that *vmin* maps to 0 and *vmax* maps to 1.

    Works element-wise on anything supporting ``-`` and ``/`` (scalars, NumPy
    arrays, Polars series). No clipping is applied, and there is no guard
    against ``vmin == vmax``.
    """
    span = vmax - vmin
    offset = x - vmin
    return offset / span


class ColorTransform(BaseModel):
    """Mapping from per-residue values to colours.

    Used both for the 3D structure colouring (:meth:`molstar_colors`) and for
    the chart's colour scale (:attr:`altair_scale`).
    """

    name: str = "tol:rainbow_PuRd"  # colormap name passed to ``cmap.Colormap``
    norm_type: str = "linear"  # "categorical" uses values as-is; otherwise rescaled by vmin/vmax
    vmin: float = 0.0  # value mapped to the low end of the colormap
    vmax: float = 1.0  # value mapped to the high end of the colormap
    missing_data_color: str = "#8c8c8c"  # colour for residues without a value
    highlight_color: str = "#e933f8"  # colour used for hover/selection highlights

    def molstar_colors(self, data: pl.DataFrame) -> dict:
        """Build a PDBe Molstar ``color_data`` payload from per-residue values.

        Rows containing a null are dropped first. Residues without a colour
        entry are drawn with ``missing_data_color``.

        Args:
            data: Frame with ``residue_number`` and ``value`` columns.

        Returns:
            Dict with a ``data`` list of ``{"residue_number", "color"}`` entries
            (lowercase ``#rrggbb``) and a ``nonSelectedColor`` fallback.
        """
        data = data.drop_nulls()
        if self.norm_type == "categorical":
            # NOTE(review): categorical values go to the colormap unnormalised,
            # presumably as integer category indices — confirm with the data producer.
            values = data["value"]
        else:
            values = norm(data["value"], vmin=self.vmin, vmax=self.vmax)

        # reshape(-1, 4) keeps an (N, 4) layout even for a single residue; the
        # previous ``.squeeze()`` produced a 0-d array there, which zip() cannot iterate.
        rgba = np.asarray(self.cmap(values, bytes=True), dtype=np.uint8).reshape(-1, 4)
        # Format the RGB channels directly instead of reinterpreting the bytes as
        # uint32, which depends on the host byte order. Alpha is discarded.
        hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgba[:, :3].tolist()]

        return {
            "data": [
                {"residue_number": resi, "color": color}
                for resi, color in zip(data["residue_number"], hex_colors)
            ],
            "nonSelectedColor": self.missing_data_color,
        }

    @property
    def cmap(self) -> Colormap:
        """The configured colormap; NaN/bad values render as ``missing_data_color``."""
        return Colormap(self.name, bad=self.missing_data_color)

    @property
    def altair_scale(self) -> alt.Scale:
        """Altair colour scale matching :meth:`molstar_colors`.

        Categorical maps use one colour per category index ``0..num_colors-1``.
        Otherwise the colormap is sampled at 256 evenly spaced stops between
        ``vmin`` and ``vmax``. Values outside the domain are clamped.
        """
        if self.norm_type == "categorical":
            n_colors = self.cmap.num_colors
            colors = self.cmap.to_altair(N=n_colors)
            domain = list(range(n_colors))
        else:
            n_stops = 256
            # Pass N explicitly so the colour list and the domain always match in length.
            colors = self.cmap.to_altair(N=n_stops)
            domain = list(np.linspace(self.vmin, self.vmax, n_stops, endpoint=True))

        return alt.Scale(domain=domain, range=colors, clamp=True)


class AxisProperties(BaseModel):
    """Labelling and scaling options for the chart's value axis."""

    label: str = "x"  # quantity name, used in axis titles and structure tooltips
    unit: str = "au"  # unit shown after the label
    autoscale_y: bool = True  # when False the y domain is fixed from the colour settings

    @property
    def title(self) -> str:
        """Axis title of the form ``"label (unit)"``."""
        return "{} ({})".format(self.label, self.unit)


def make_chart(
    data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties
) -> alt.LayerChart:
    """Build the layered per-residue scatter chart.

    Layers, bottom to top:
      1. Coloured scatter points, with x-axis wheel zoom (shift+wheel excluded).
      2. A vertical rule positioned by the ``line_position`` / ``line_opacity``
         params, which the caller sets through chart params.
      3. An invisible point layer carrying the ``point`` nearest-residue hover
         selection.
      4. A rule drawn at the residue picked by that selection.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns; must be non-empty.
        colors: Colour mapping. Also fixes the y domain when autoscaling is off.
        axis_properties: Axis title and y autoscale flag.

    Returns:
        The layered Altair chart: container width, 480 px high.
    """
    residues = data["residue_number"]
    xmin, xmax = residues.min(), residues.max()
    x_span = xmax - xmin
    # A single residue gives a zero span; pad by one residue instead so the x
    # domain is not degenerate.
    xpad = x_span * 0.05 if x_span > 0 else 1.0
    xscale = alt.Scale(domain=(xmin - xpad, xmax + xpad))

    if axis_properties.autoscale_y:
        y_scale = alt.Scale()
    elif colors.norm_type == "categorical":
        # Category indices run from 0 to num_colors - 1.
        n_colors = colors.cmap.num_colors
        ypad = n_colors * PAD_SIZE
        y_scale = alt.Scale(domain=(-ypad, n_colors - 1 + ypad))
    else:
        ypad = (colors.vmax - colors.vmin) * PAD_SIZE
        y_scale = alt.Scale(domain=(colors.vmin - ypad, colors.vmax + ypad))

    zoom_x = alt.selection_interval(
        bind="scales",
        encodings=["x"],
        zoom="wheel![!event.shiftKey]",
    )

    scatter = (
        alt.Chart(data)
        .mark_circle(size=200)
        .encode(
            x=alt.X("residue_number:Q", title="Residue Number", scale=xscale),
            y=alt.Y(
                "value:Q",
                title=axis_properties.title,
                scale=y_scale,
            ),
            color=alt.Color(
                f"value:{'O' if colors.norm_type == 'categorical' else 'Q'}",
                scale=colors.altair_scale,
                title=axis_properties.title,
            ),
        )
        .add_params(zoom_x)
    )

    # Selection that picks the nearest point on hover, keyed on residue number.
    nearest = alt.selection_point(
        name="point",
        nearest=True,
        on="pointerover",
        fields=["residue_number"],
        empty=False,
        clear="mouseout",
    )

    select_residue = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x="residue_number:Q",
            opacity=alt.value(0),
        )
        .add_params(nearest)
    )

    # Draw a rule at the location of the selection.
    rule = (
        alt.Chart(data)
        .mark_rule(color=colors.highlight_color, size=2)
        .encode(
            x="residue_number:Q",
        )
        .transform_filter(nearest)
    )

    line_position = alt.param(name="line_position", value=0.0)
    line_opacity = alt.param(name="line_opacity", value=1)
    # One-row frame with x = 1: the rule's x is computed as 1 * line_position, so it
    # can be moved by updating the param without supplying new data.
    df_line = pl.DataFrame({"x": [1.0]})

    vline = (
        alt.Chart(df_line)
        .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2)
        .encode(x=alt.X("p", type="quantitative"))
        .transform_calculate(p=alt.datum.x * line_position)
        .add_params(line_position, line_opacity)
    )

    # Combine the four layers; they share the x scale.
    chart = alt.layer(scatter, vline, select_residue, rule).properties(
        width="container",
        height=480,
    )

    return chart


@solara.component
def ScatterChart(
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    on_selections,
    line_value,
):
    """Render the per-residue chart and forward its hover selection.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Colour mapping for the points.
        axis_properties: Axis title and y-scaling options.
        on_selections: traitlets observer, called with the change event when the
            chart's ``point`` selection changes.
        line_value: Residue number at which to draw the highlight rule, or
            ``None`` to hide it.
    """
    chart = solara.use_memo(
        lambda: make_chart(data, colors, axis_properties),
        dependencies=[data, colors, axis_properties],
    )

    # Move or hide the highlight rule through chart params instead of rebuilding the chart.
    if line_value is None:
        params = {"line_position": 0.0, "line_opacity": 0}
    else:
        params = {"line_position": line_value, "line_opacity": 1}

    options = {"actions": False}
    if solara.lab.use_dark_effective():
        options["theme"] = "dark"

    view = alt.JupyterChart.element(  # type: ignore
        chart=chart,
        embed_options=options,
        _params=params,
    )

    def bind():
        selections = solara.get_widget(view).selections  # type: ignore
        selections.observe(on_selections, "point")

        def unbind():
            # Remove the observer before the effect re-runs (or the component
            # unmounts); otherwise registrations could accumulate and the
            # callback would fire several times per selection change.
            selections.unobserve(on_selections, "point")

        return unbind

    # Same dependencies as the chart memo: re-bind whenever the chart is rebuilt,
    # since the widget's selections object may change along with it.
    solara.use_effect(bind, [data, colors, axis_properties])


def is_numeric(val) -> bool:
    """Return True if *val* is a number that is not NaN.

    ``None`` and values NumPy's ``isnan`` cannot handle (e.g. strings) return
    False rather than raising.
    """
    if val is None:
        return False
    try:
        return not bool(np.isnan(val))
    except TypeError:
        # np.isnan rejects non-numeric input (str, object, ...) with TypeError.
        return False


@solara.component
def ProteinView(
    title: str,
    molecule_id: str,
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    dark_effective: bool,
    description: str = "",
):
    """Page layout: a 3D structure coloured by per-residue values next to a value chart.

    Hovering a residue in the structure draws a rule at that residue in the
    chart. Hovering a point in the chart highlights that residue in the structure.

    Args:
        title: App bar title.
        molecule_id: Structure identifier, passed lower-cased to PDBe Molstar.
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Colour mapping for the structure and the chart.
        axis_properties: Label and unit used in the chart and the tooltips.
        dark_effective: Whether the dark theme is active. It is part of the
            Molstar element's key, so toggling it remounts the viewer.
        description: Markdown for the "About" dialog. The About button is
            hidden when this is empty.
    """
    about_dialog = solara.use_reactive(False)
    fullscreen = solara.use_reactive(False)

    # residue number to highlight in altair chart
    line_number = solara.use_reactive(None)

    # residue number to highlight in protein view
    highlight_number = solara.use_reactive(None)

    if data.is_empty():
        color_data = {}
    else:
        color_data = colors.molstar_colors(data)

    tooltips = {
        "data": [
            {
                "residue_number": resi,
                "tooltip": f"{axis_properties.label}: {value:.2g} {axis_properties.unit}"
                if is_numeric(value)
                else "No data",
            }
            for resi, value in zip(data["residue_number"], data["value"])
        ]
    }

    # Compare with None explicitly: residue number 0 is valid but falsy.
    highlight = (
        {"data": [{"residue_number": int(highlight_number.value)}]}
        if highlight_number.value is not None
        else None
    )

    def on_molstar_mouseover(value):
        r = value.get("residueNumber", None)
        line_number.set(r)

    def on_molstar_mouseout(value):
        on_molstar_mouseover({})

    def on_chart_selection(event):
        # ``event["new"]`` is the chart's ``point`` selection. An empty, cleared or
        # missing selection has no first entry, so it clears the highlight.
        try:
            r = event["new"].value[0]["residue_number"]
        except (IndexError, KeyError, TypeError, AttributeError):
            highlight_number.set(None)
        else:
            highlight_number.set(r)

    with solara.AppBar():
        solara.AppBarTitle(title)
        with solara.Tooltip("Fullscreen"):
            solara.Button(
                icon_name="mdi-fullscreen",
                icon=True,
                on_click=lambda: fullscreen.set(not fullscreen.value),
            )
        if description:
            with solara.Tooltip("About"):
                solara.Button(
                    icon_name="mdi-information-outline",
                    icon=True,
                    on_click=lambda: about_dialog.set(True),
                )
        solara.lab.ThemeToggle()

    with solara.v.Dialog(
        v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False)
    ):
        with solara.Card("About", margin=0):
            solara.Markdown(description)

    with solara.ColumnsResponsive([4, 8]):
        with solara.Card(style={"height": "550px"}):
            PDBeMolstar.element(  # type: ignore
                theme="dark" if dark_effective else "light",
                molecule_id=molecule_id.lower(),
                color_data=color_data,
                hide_water=True,
                tooltips=tooltips,
                height="525px",
                highlight=highlight,
                highlight_color=colors.highlight_color,
                on_mouseover_event=on_molstar_mouseover,
                on_mouseout_event=on_molstar_mouseout,
                hide_controls_icon=True,
                hide_expand_icon=True,
                hide_settings_icon=True,
                expanded=fullscreen.value,
            ).key(f"molstar-{dark_effective}")
        # In fullscreen mode only the structure is shown.
        if not fullscreen.value:
            with solara.Card(style={"height": "550px"}):
                if data.is_empty():
                    solara.Text("No data")
                else:
                    ScatterChart(
                        data,
                        colors,
                        axis_properties,
                        on_chart_selection,
                        line_number.value,
                    )


@solara.component
def RoutedView():
    """Top-level page: build a ProteinView from the URL query string.

    Required query keys are ``data`` (decoded by ``decode_data``), ``title`` and
    ``molecule_id``; ``description`` is optional. The whole query is also passed
    to ``ColorTransform`` and ``AxisProperties``, so their fields can be set from
    the URL. A missing key or an invalid value shows a warning instead of the view.
    """
    route = solara.use_router()
    dark_effective = solara.lab.use_dark_effective()

    query_dict = dict(parse_qsl(route.search))
    try:
        colors = ColorTransform(**query_dict)  # type: ignore
        axis_properties = AxisProperties(**query_dict)  # type: ignore
        data = decode_data(query_dict["data"])
        title = query_dict["title"]
        molecule_id = query_dict["molecule_id"]
    except (KeyError, ValueError) as err:
        # KeyError: a required query key is missing. ValueError covers pydantic
        # validation failures (ValidationError subclasses ValueError) and,
        # presumably, malformed payloads in decode_data — verify against make_link.
        solara.Warning(f"Error: {err}")
        return

    ProteinView(
        title,
        molecule_id=molecule_id,
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
        description=query_dict.get("description", ""),
    )