from collections import deque
from src.energy_prediction.EnergyPredictionModel import EnergyPredictionModel
from src.energy_prediction.EnergyPredictionPipeline import EnergyPredictionPipeline
from src.vav.VAVAnomalizer import VAVAnomalizer
from src.vav.VAVPipeline import VAVPipeline
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mqtt_client
import time
from src.rtu.RTUPipeline import RTUPipeline
from src.rtu.RTUAnomalizer1 import RTUAnomalizer1
from src.rtu.RTUAnomalizer2 import RTUAnomalizer2
import plotly.express as px
import sys
import subprocess

# subprocess.run([f"{sys.executable}", "mqttpublisher.py"])

# RTU preprocessing pipeline; it loads two scalers (files named for RTUs 1-2 and 3-4).
rtu_data_pipeline = RTUPipeline(
    scaler1_path="src/rtu/models/scaler_rtu_1_2.pkl",
    scaler2_path="src/rtu/models/scaler_rtu_3_4.pkl",
)

# Running energy statistics shown in the "Energy Usage Statistics" panel.
average_energy = 0
max_energy = 0

# Two anomaly detectors, in this order: index 0 is built from the RTU 1/2
# model files, index 1 from the RTU 3/4 model files.
rtu_anomalizers = [
    RTUAnomalizer1(
        prediction_model_path="src/rtu/models/lstm_2rtu_smooth_04.keras",
        clustering_model_paths=[
            "src/rtu/models/kmeans_rtu_1.pkl",
            "src/rtu/models/kmeans_rtu_2.pkl",
        ],
        pca_model_paths=[
            "src/rtu/models/pca_rtu_1.pkl",
            "src/rtu/models/pca_rtu_2.pkl",
        ],
        num_inputs=rtu_data_pipeline.num_inputs,
        num_outputs=rtu_data_pipeline.num_outputs,
    ),
    RTUAnomalizer2(
        prediction_model_path="src/rtu/models/lstm_2rtu_smooth_03.keras",
        clustering_model_paths=[
            "src/rtu/models/kmeans_rtu_3.pkl",
            "src/rtu/models/kmeans_rtu_4.pkl",
        ],
        pca_model_paths=[
            "src/rtu/models/pca_rtu_3.pkl",
            "src/rtu/models/pca_rtu_4.pkl",
        ],
        num_inputs=rtu_data_pipeline.num_inputs,
        num_outputs=rtu_data_pipeline.num_outputs,
    ),
]

# One VAV preprocessing pipeline per RTU id (1-4), each with its own scaler file.
vav_pipelines = [
    VAVPipeline(rtu_id=rtu_id, scaler_path=f"src/vav/models/scaler_vav_{rtu_id}.pkl")
    for rtu_id in range(1, 5)
]

# One VAV anomalizer per pipeline; the input/output sizes are copied from the
# pipeline that shares its RTU id.
vav_anomalizers = [
    VAVAnomalizer(
        rtu_id=rtu_id,
        prediction_model_path=f"src/vav/models/lstm_vav_0{rtu_id}.keras",
        clustering_model_path=f"src/vav/models/kmeans_vav_{rtu_id}.pkl",
        pca_model_path=f"src/vav/models/pca_vav_{rtu_id}.pkl",
        num_inputs=matching_pipeline.num_inputs,
        num_outputs=matching_pipeline.num_outputs,
    )
    for rtu_id, matching_pipeline in enumerate(vav_pipelines, start=1)
]


# History seeded from disk; the streaming loop appends live rows to it.
all_data = pd.read_csv("bootstrap_data.csv")

# Fault log table (initially empty) and one "fault currently active" flag per RTU.
df_faults = pd.DataFrame(columns=["_______Time_______", "__________Issue__________"])
current_stat = [False] * 4

# Per-wing energy preprocessing pipelines, both seeded with the bootstrap history.
energy_pipeline_north = EnergyPredictionPipeline(
    wing="north",
    scaler_path="src/energy_prediction/models/scalerNorth.pkl",
    bootstrap_data=all_data,
)
energy_pipeline_south = EnergyPredictionPipeline(
    wing="south",
    scaler_path="src/energy_prediction/models/scalerSouth.pkl",
    bootstrap_data=all_data,
)

# Per-wing energy prediction models.
energy_prediction_model_north = EnergyPredictionModel(
    model_path="src/energy_prediction/models/lstm_energy_north_01.keras"
)
energy_prediction_model_south = EnergyPredictionModel(
    model_path="src/energy_prediction/models/lstm_energy_south_01.keras"
)


# Use the full browser width for the dashboard.
st.set_page_config(layout="wide")


# Energy data generating used in Energy Usage Over Time plot ---- REPLACE WITH ACTUAL DATA ----
def generate_energy_data(periods=200):
    """Return synthetic placeholder energy data shaped like a random walk.

    Args:
        periods: Number of one-minute samples to generate. Defaults to 200,
            the size that was previously hard-coded.

    Returns:
        A DataFrame with a "Time" column (one-minute steps starting at
        2021-01-01) and an "Energy" column (cumulative sum of standard-normal
        draws, so values differ on every call).
    """
    times = pd.date_range("2021-01-01", periods=periods, freq="1min")
    # A single `periods` value drives both columns so they cannot drift apart.
    energy = np.random.randn(periods).cumsum()
    return pd.DataFrame({"Time": times, "Energy": energy})


# Header row: logo on the left, title in the middle, clock on the right
header_row1_col1, header_row1_col2, header_row1_col3 = st.columns([0.8, 3, 1])

# Logo in the first header column
header_row1_col1.image("logo.png")

# Centered dashboard title in the second header column
header_row1_col2.markdown(
    "<h1 style='text-align: center;'>Building 59 - HVAC Dashboard</h1>",
    unsafe_allow_html=True,
)

# Start the MQTT client, then reserve the third header column for the
# timestamp that the streaming loop writes.
mqtt_client.start_mqtt_client()
placeholder_header_time = header_row1_col3.empty()

# First content row: RTU status | zones | fault log
row1_col1, row1_col2, row1_col3 = st.columns([1.1, 1, 0.75])


# RTU Status panel lives in its own container inside the first column
rtu_status_container = row1_col1.container()
rtu_status_container.markdown(
    """
    <div style="background-color:#E2F0D9;padding:1px;border-radius:5px;margin-bottom:20px">
    <h3 style="color:black;text-align:center;">RTU Status</h3>
    </div>""",
    unsafe_allow_html=True,
)

rtu_placeholders = []
rtu_columns = rtu_status_container.columns(4)

# Build one set of placeholders per RTU and show the initial "no data" state;
# live values are written into these slots later.
for rtu_number, rtu_column in enumerate(rtu_columns, start=1):
    with rtu_column:
        box_slot = st.empty()
        sa_temp_slot = st.empty()
        ra_temp_slot = st.empty()

    rtu_placeholders.append(
        {"box": box_slot, "sa_temp": sa_temp_slot, "ra_temp": ra_temp_slot}
    )

    box_slot.markdown(
        f"""
        <div style='background-color:#447F80;padding:3px;border-radius:5px;margin-bottom:10px'>
            <h4 style='color:black;text-align:center;'>RTU{rtu_number}</h4>
        </div>
        """,
        unsafe_allow_html=True,
    )
    sa_temp_slot.markdown("**SA temp:**  --  °F")
    ra_temp_slot.markdown("**RA temp:**  --  °F")


# Temperatures streaming and updates
def update_status_boxes(df, fault):
    """Refresh the four RTU status tiles from the latest data row.

    Args:
        df: DataFrame holding ``rtu_00N_sa_temp`` and ``rtu_00N_ra_temp``
            columns for N in 1..4; only the last row is read.
        fault: Indexable with one entry per RTU. An entry equal to 1 paints
            the tile red, an entry equal to 0 paints it teal, and any other
            value (e.g. None) leaves the tile as it was.
    """
    for rtu_number in range(1, 5):
        slots = rtu_placeholders[rtu_number - 1]

        latest_sa = df[f"rtu_00{rtu_number}_sa_temp"].iloc[-1]
        latest_ra = df[f"rtu_00{rtu_number}_ra_temp"].iloc[-1]
        slots["sa_temp"].markdown(f"**SA temp:**  {latest_sa} °F")
        slots["ra_temp"].markdown(f"**RA temp:**  {latest_ra} °F")

        status = fault[rtu_number - 1]
        if status == 1:
            box_color = "#ff4d4d"
        elif status == 0:
            box_color = "#447F80"
        else:
            # Unknown status: keep whatever the tile currently shows.
            continue

        slots["box"].markdown(
            f"""
            <div style='background-color:{box_color};padding:3px;border-radius:5px;margin-bottom:10px'>
                <h4 style='color:black;text-align:center;'>RTU{rtu_number}</h4>
            </div>
            """,
            unsafe_allow_html=True,
        )


# Zones
def _render_zone_rows(zone_numbers, num_cols=7):
    """Draw one white button per zone, ``num_cols`` per row.

    Buttons are written into whichever Streamlit container is active when the
    function is called. ``zone_numbers`` must be a list; its order is the
    display order.
    """
    for row_start in range(0, len(zone_numbers), num_cols):
        row_columns = st.columns(num_cols)
        row_zones = zone_numbers[row_start : row_start + num_cols]
        for column, zone_number in zip(row_columns, row_zones):
            button_html = f'<button style="width:100%; height:50px; border:none; color:black; background-color:#FFFFFF">{zone_number}</button>'
            with column:
                st.markdown(button_html, unsafe_allow_html=True)


with row1_col2:
    st.markdown(
        """
    <div style="background-color:#E2F0D9;padding:1px;border-radius:5px;margin-bottom:20px">
    <h3 style="color:black;text-align:center;">Zones</h3>
    </div>""",
        unsafe_allow_html=True,
    )

    tab1, tab2, tab3, tab4 = st.tabs(["RTU 1", "RTU 2", "RTU 3", "RTU 4"])

    # Zones listed under each RTU tab, in display order.
    with tab1:
        _render_zone_rows([36, 37, 38, 39, 40, 41, 42, 64, 65, 66, 67, 68, 69, 70])

    with tab2:
        _render_zone_rows(
            [
                19, 20, 27, 28, 29, 30, 31,
                32, 33, 34, 35, 43, 44, 49,
                50, 57, 58, 59, 60, 62, 63,
                71, 72,
            ]
        )

    with tab3:
        _render_zone_rows(sorted([18, 25, 26, 45, 48, 55, 56, 61]))

    with tab4:
        _render_zone_rows(sorted([16, 17, 21, 22, 23, 24, 46, 47, 51, 52, 53, 54]))

# Faults
# Reserve a heading slot and a table slot in the third column; the table slot
# is re-rendered whenever the fault log changes.
with row1_col3:
    fault_heading_slot = st.empty()
    fault_table_slot = st.empty()

fault_placeholder = {"heading": fault_heading_slot, "dataframe": fault_table_slot}

fault_heading_slot.markdown(
    """
    <div style="background-color:#E2F0D9;padding:1px;border-radius:5px;margin-bottom:20px">
    <h3 style="color:black;text-align:center;">Fault Log</h3>
    </div>""",
    unsafe_allow_html=True,
)

# Show the (initially empty) fault table.
fault_table_slot.dataframe(df_faults)


def fault_table_update(fault, df_faults, current_stat, df_time):
    """Append fault start/end events to the fault log and redraw the table.

    A "Start" row is logged when an RTU's fault flag is 1 and no fault was
    active for it; an "End" row is logged when the flag is 0 while a fault was
    active. Any other flag value (e.g. None) leaves that RTU's state untouched.

    Args:
        fault: Indexable with one fault flag per RTU (four entries are read).
        df_faults: Fault log DataFrame; rows are appended in place.
        current_stat: List of four booleans tracking whether a fault is
            currently active per RTU; updated in place.
        df_time: Timestamp value written into the first column of new rows.
    """
    for i in range(4):
        if fault[i] == 1 and not current_stat[i]:
            df_faults.loc[len(df_faults)] = [
                df_time,
                f"RTU_0{i+1}_fan/damper_fault - Start",
            ]
            current_stat[i] = True
        elif fault[i] == 0 and current_stat[i]:
            df_faults.loc[len(df_faults)] = [
                df_time,
                f"RTU_0{i+1}_fan/damper_fault - End",
            ]
            current_stat[i] = False

    # Redraw once after all RTUs are processed instead of once per RTU.
    fault_placeholder["dataframe"].dataframe(df_faults)


# Details
with st.container():
    st.markdown(
        """
    <div style="background-color:#E2F0D9;padding:1px;border-radius:5px;margin-bottom:20px">
    <h3 style="color:black;text-align:center;">Details</h3>
    </div>""",
        unsafe_allow_html=True,
    )

    # Two columns: floor map on the left, energy panels on the right
    row2_row1_col1, row2_row1_col2 = st.columns([0.9, 1.5])

    # Floor Plan
    with row2_row1_col1:
        st.subheader("Floor Map")
        st.image("floor_plan.jpg", use_column_width=True)

    # Energy Consumption Plots and Statistics
    with row2_row1_col2:
        row2_row2_col1, row2_row2_col2 = st.columns(2)

        # Both wing charts are stacked in the left sub-column; the empty
        # slots are filled by the streaming loop.
        with row2_row2_col1:
            st.subheader("Energy Usage - North Wing")
            north_wing_energy_container = st.empty()

            st.subheader("Energy Usage - South Wing")
            south_wing_energy_container = st.empty()

        # Energy Consumption Statistics go in the right sub-column
        with row2_row2_col2:
            stats_heading_slot = st.empty()
            stats_text_slot = st.empty()

        energy_stats_placeholder = {"box": stats_heading_slot, "sub": stats_text_slot}
        stats_heading_slot.subheader("Energy Usage Statistics")
        # Initial values; the streaming loop overwrites this text.
        stats_text_slot.text(
            f"Average: {int(average_energy)} kW\nHighest: {int(max_energy)} kW"
        )


distances = []


def create_residual_plot(resid_pca_list, distance, rtu_id, lim=8):
    """Build a square Plotly scatter of two residual-PCA columns.

    Args:
        resid_pca_list: 2-D array indexed as ``[:, column]``. Columns 0 and 1
            are plotted when ``rtu_id`` is odd, columns 2 and 3 when it is even.
        distance: Values passed to Plotly as the point colour (may be None).
        rtu_id: Integer unit id; only its parity is used.
        lim: Both axes are fixed to the range ``[-lim, lim]``.

    Returns:
        The configured Plotly figure (500x500, legend hidden).
    """
    parity = rtu_id % 2
    if parity == 1:
        x_col, y_col = 0, 1
    elif parity == 0:
        x_col, y_col = 2, 3

    fig = px.scatter(
        x=resid_pca_list[:, x_col],
        y=resid_pca_list[:, y_col],
        color=distance,
        labels={"x": "Time", "y": "Residual"},
        width=500,
        height=500,
        color_discrete_sequence=px.colors.qualitative.Set2,
    )

    # Same light-gray grid on both axes; each axis gets its own dict copy.
    grid_style = {"showgrid": True, "gridwidth": 1, "gridcolor": "lightgray"}
    layout_options = {
        "xaxis_range": [-lim, lim],
        "yaxis_range": [-lim, lim],
        "xaxis": dict(grid_style),
        "yaxis": dict(grid_style),
        "margin": dict(l=20, r=20, t=20, b=20),
        "hovermode": "closest",
        "showlegend": False,
        "autosize": False,
        "hoverlabel": dict(bgcolor="white", font_size=12),
        "hoverlabel_align": "left",
        "hoverlabel_font_color": "black",
        "hoverlabel_bordercolor": "lightgray",
    }
    fig.update_layout(**layout_options)

    return fig


resid_placeholder = st.empty()

resid_vav_placeholder = st.empty()

k = 0  # Number of batches folded into the running energy average.

# Maximum rows of history kept in memory. 10080 == 7 * 24 * 60, presumably one
# week of one-minute samples -- confirm against the publisher's rate.
MAX_HISTORY_ROWS = 10080

# Pause while the MQTT buffer is empty so the loop does not spin a CPU core.
POLL_INTERVAL_SECONDS = 0.1


def _render_residual_grid(target, panels, **plot_kwargs):
    """Draw residual scatter plots, two per row, replacing ``target``'s content.

    Args:
        target: Streamlit placeholder (``st.empty()``) to draw into.
        panels: Sequence of ``(title, resid_pca, distance, rtu_id)`` tuples.
            A panel whose ``resid_pca`` is None is left blank instead of
            raising.
        **plot_kwargs: Extra keyword arguments forwarded to
            ``create_residual_plot`` (e.g. ``lim``).
    """
    with target.container():
        for first in range(0, len(panels), 2):
            row_columns = st.columns(2)
            for column, panel in zip(row_columns, panels[first : first + 2]):
                title, resid_pca, distance, rtu_id = panel
                if resid_pca is None:
                    continue
                with column:
                    st.subheader(title)
                    fig = create_residual_plot(
                        np.array(resid_pca), distance, rtu_id=rtu_id, **plot_kwargs
                    )
                    st.plotly_chart(fig)


def _render_energy_forecast(target, forecast, start_time):
    """Plot an hourly forecast line chart into the ``target`` placeholder.

    The x axis starts at ``start_time`` and advances one hour per forecast
    value.
    """
    x_time = pd.date_range(start_time, periods=len(forecast), freq="1h")
    with target:
        fig = px.line(
            x=x_time,
            y=forecast,
            # Same unit label for both wings, matching the statistics panel.
            labels={"x": "Time", "y": "Energy (kW)"},
            height=200,
        )
        st.plotly_chart(fig)


# Streaming loop: consume buffered MQTT rows, run the RTU/VAV/energy models on
# them and refresh every dashboard placeholder. Runs for the session lifetime.
while True:
    # Take ONE snapshot of the buffer so every consumer below sees the same
    # rows, then empty the buffer right away: rows that arrive while this
    # iteration is running stay queued for the next pass instead of being
    # discarded at the end of it.
    # NOTE(review): an append landing between the copy and clear() can still
    # be lost; closing that window needs a lock inside mqtt_client.
    batch = list(mqtt_client.data_list)
    if not batch:
        time.sleep(POLL_INTERVAL_SECONDS)
        continue
    mqtt_client.data_list.clear()
    new_rows = pd.DataFrame(batch)

    all_data = pd.concat([all_data, new_rows], axis=0)
    if len(all_data) > MAX_HISTORY_ROWS:
        all_data = all_data.iloc[-MAX_HISTORY_ROWS:]

    df_time = all_data["date"].iloc[-1]  # Latest timestamp in the history

    placeholder_header_time.markdown(
        f"""
            <h2 style='text-align: center;'> 🕒 {df_time}</h2>
            """,
        unsafe_allow_html=True,
    )

    # Running energy statistics. Only the first row of the batch is used, as
    # before.
    energy = new_rows["hvac_N"][0].item() + new_rows["hvac_S"][0].item()
    k += 1
    average_energy = average_energy + (energy - average_energy) / k
    if energy > max_energy:
        max_energy = energy

    energy_stats_placeholder["sub"].text(
        f"Average: {int(average_energy)} kW\nHighest: {int(max_energy)} kW"
    )

    # --- Preprocessing -----------------------------------------------------
    # Each pipeline gets its own copy of the batch so an in-place change made
    # by one pipeline cannot leak into the next.
    df_new1, df_trans1, df_new2, df_trans2 = rtu_data_pipeline.fit(new_rows.copy())

    vav_frames = []
    for vav_pipeline, vav_anomalizer in zip(vav_pipelines, vav_anomalizers):
        vav_df_new, vav_df_trans = vav_pipeline.fit(new_rows.copy())
        vav_frames.append((vav_df_new, vav_df_trans))
        # Keep the anomalizer's sizes in sync with its pipeline after each fit.
        vav_anomalizer.num_inputs = vav_pipeline.num_inputs
        vav_anomalizer.num_outputs = vav_pipeline.num_outputs

    energy_df_north = energy_pipeline_north.fit(all_data)
    energy_df_south = energy_pipeline_south.fit(all_data)

    # --- RTU anomaly detection ---------------------------------------------
    resid_pca_list_rtu = None
    resid_pca_list_rtu_2 = None
    rtu_1_distance = rtu_2_distance = rtu_3_distance = rtu_4_distance = None
    fault_1 = fault_2 = fault_3 = fault_4 = None

    rtu_frames = (df_new1, df_trans1, df_new2, df_trans2)
    if all(frame is not None for frame in rtu_frames):
        (
            _,
            _,
            _,
            resid_pca_list_rtu,
            _,
            rtu_1_distance,
            rtu_2_distance,
            fault_1,
            fault_2,
        ) = rtu_anomalizers[0].pipeline(df_new1, df_trans1, rtu_data_pipeline.scaler1)
        (
            _,
            _,
            _,
            resid_pca_list_rtu_2,
            _,
            rtu_3_distance,
            rtu_4_distance,
            fault_3,
            fault_4,
        ) = rtu_anomalizers[1].pipeline(df_new2, df_trans2, rtu_data_pipeline.scaler2)

    # --- VAV anomaly detection ---------------------------------------------
    # Every entry starts as None on each iteration, so a VAV whose pipeline
    # produced nothing is never plotted with stale or undefined residuals.
    vav_resid_pca = []
    for (vav_df_new, vav_df_trans), vav_pipeline, vav_anomalizer in zip(
        vav_frames, vav_pipelines, vav_anomalizers
    ):
        resid_pca = None
        if vav_df_new is not None:
            _, _, _, resid_pca, _ = vav_anomalizer.pipeline(
                vav_df_new, vav_df_trans, vav_pipeline.scaler
            )
        vav_resid_pca.append(resid_pca)

    # --- Residual plots ----------------------------------------------------
    if resid_pca_list_rtu is not None:
        _render_residual_grid(
            resid_placeholder,
            [
                ("RTU 1 Residuals", resid_pca_list_rtu, rtu_1_distance, 1),
                ("RTU 2 Residuals", resid_pca_list_rtu, rtu_2_distance, 2),
                # RTU 3/4 residuals come from the second anomalizer.
                ("RTU 3 Residuals", resid_pca_list_rtu_2, rtu_3_distance, 3),
                ("RTU 4 Residuals", resid_pca_list_rtu_2, rtu_4_distance, 4),
            ],
        )

    if any(resid_pca is not None for resid_pca in vav_resid_pca):
        # NOTE(review): the VAV points are coloured by rtu_4_distance, as in
        # the previous version. This looks like a copy-paste leftover; confirm
        # which per-VAV value should drive the colour.
        _render_residual_grid(
            resid_vav_placeholder,
            [
                (f"VAV {vav_number} Residuals", resid_pca, rtu_4_distance, 1)
                for vav_number, resid_pca in enumerate(vav_resid_pca, start=1)
            ],
            lim=15,
        )

    # --- Energy forecasts --------------------------------------------------
    current_time = pd.to_datetime(df_time)

    if energy_df_north is not None:
        energy_prediction_north = energy_prediction_model_north.pipeline(
            energy_df_north, energy_pipeline_north.scaler
        ).flatten()
        _render_energy_forecast(
            north_wing_energy_container, energy_prediction_north, current_time
        )

    if energy_df_south is not None:
        energy_prediction_south = energy_prediction_model_south.pipeline(
            energy_df_south, energy_pipeline_south.scaler
        ).flatten()
        _render_energy_forecast(
            south_wing_energy_container, energy_prediction_south, current_time
        )

    # --- Status tiles and fault log ----------------------------------------
    rtu_faults = [fault_1, fault_2, fault_3, fault_4]
    update_status_boxes(all_data, rtu_faults)
    fault_table_update(rtu_faults, df_faults, current_stat, df_time)