import os
import time
import httpx
import string
import random
import datetime as dt
from dotenv import load_dotenv

import streamlit as st
import extra_streamlit_components as stx

import asyncio
from aiocache import cached, Cache

import pandas as pd
from typing import Optional, Callable

from config import ENV_PATH, BEST_MODELS, TEST_FILE, TEST_FILE_URL, HISTORY_FILE, markdown_table_all

from utils.navigation import navigation
from utils.footer import footer
from utils.janitor import Janitor


# Pull environment variables from ENV_PATH; API_URL is read later by endpoint().
load_dotenv(ENV_PATH)

# Page-wide title, icon and layout for this Streamlit page.
_PAGE_CONFIG = {
    "page_title": "Homepage",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "auto",
}
st.set_page_config(**_PAGE_CONFIG)


# aiocache is used because Streamlit's own cache does not support async functions yet.
# NOTE(review): aiocache appears to treat a cached None as a miss, so this decorator
# likely never short-circuits a call (the function returns None) — confirm it is wanted.
@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_savedataset')
async def save_dataset(df: pd.DataFrame, filepath, csv=True) -> None:
    """Write *df* to *filepath* unless the file already holds identical data.

    Args:
        df: Data to persist (written without its index).
        filepath: Destination path.
        csv: True to use CSV (``to_csv``/``read_csv``), False for Excel
            (``to_excel``/``read_excel``).

    A round trip through the file format may change dtypes; in that case
    ``DataFrame.equals`` reports a difference and the file is simply rewritten.
    """
    if os.path.isfile(filepath):
        on_disk = pd.read_csv(filepath) if csv else pd.read_excel(filepath)
        if df.equals(on_disk):
            return  # identical copy already on disk; skip the write

    if csv:
        df.to_csv(filepath, index=False)
    else:
        df.to_excel(filepath, index=False)


@cached(ttl=10, cache=Cache.MEMORY, namespace='streamlit_testdata')
async def get_test_data():
    """Load the test dataset, preferring the remote copy.

    The CSV is read from TEST_FILE_URL, and the local copy at TEST_FILE is then
    refreshed. If the download fails, the local copy is read instead. If only
    the refresh of the local copy fails (e.g. unwritable path or corrupt old
    file), the freshly downloaded data is still returned. Before, the download
    was discarded and a possibly missing local file was read.

    Results are memoised by aiocache for 10 seconds. Presumably the same
    objects are handed to every caller within that window, so avoid mutating
    the returned frames in place.

    Returns:
        tuple: ``(df_test_raw, df_test)`` - the data as read, and a cleaned
        copy produced by ``Janitor.clean_dataframe``.
    """
    try:
        df_test_raw = pd.read_csv(TEST_FILE_URL)
    except Exception:
        # Remote copy unreachable: fall back to the last locally saved copy.
        df_test_raw = pd.read_csv(TEST_FILE)
    else:
        try:
            await save_dataset(df_test_raw, TEST_FILE, csv=True)
        except (OSError, ValueError):
            # Refreshing the local copy is best-effort; the download is still valid.
            pass

    # Some house keeping: clean a copy so df_test_raw stays exactly as read.
    df_test = Janitor().clean_dataframe(df_test_raw.copy())

    return df_test_raw, df_test


async def select_model() -> str:
    """Render the model picker in the left half of the page and return the choice.

    The selectbox is keyed ``'selected_model'``, so the choice is also
    available from ``st.session_state`` to the prediction callbacks.
    """
    left_column = st.columns(2)[0]
    with left_column:
        choice = st.selectbox(
            'Select a model', options=BEST_MODELS, key='selected_model')
    return choice


async def endpoint(model: str) -> str:
    """Build the prediction URL for *model*.

    The model name is appended to the ``API_URL`` environment variable after
    an ``'='``. API_URL presumably ends with a query key such as ``...?model``;
    confirm this against the .env file. If API_URL is unset, the result starts
    with the text ``'None'``.
    """
    base_url = os.getenv("API_URL")
    return "{}={}".format(base_url, model)


# Function for making prediction
async def make_prediction(model_endpoint) -> Optional[pd.DataFrame]:
    """Predict sepsis for one patient and store the result in session state.

    The patient row comes from one of two forms, read via session state:

    * search form (``search_patient`` on and ``search_patient_id`` selected):
      the matching row of the cleaned test dataset;
    * manual form (search off and ``manual_patient_id`` set): a one-row frame
      built from the manual input widgets.

    The row is POSTed to *model_endpoint* as ``{column: [values]}`` JSON. On
    success, the first prediction and probability are written to session
    state (``prediction``/``probability``) and added as columns, and the row
    is appended to HISTORY_FILE. Request or response problems are shown as an
    API error. A failure to write the history file is shown separately as a
    warning, and the prediction is still kept.

    Args:
        model_endpoint: URL of the prediction API for the selected model.

    Returns:
        The patient DataFrame (with prediction columns when the request
        succeeded), or None when no valid patient was submitted.
    """
    _, df_test = await get_test_data()

    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    search_patient_id = st.session_state.get('search_patient_id', False)
    manual_patient_id = st.session_state.get('manual_patient_id', False)
    if isinstance(search_patient_id, str) and search_patient_id:  # And not empty string
        search_patient_id = [search_patient_id]  # isin() below expects a list
    if search_patient and search_patient_id:  # Search Form df and a patient was selected
        mask = df_test['id'].isin(search_patient_id)
        df = df_test[mask].copy()  # copy so the cached test data is never mutated
    elif not (search_patient or search_patient_id) and manual_patient_id:  # Manual form df
        columns = ['manual_patient_id', 'prg', 'pl', 'pr', 'sk',
                   'ts', 'm11', 'bd2', 'age', 'insurance']
        data = {c: [st.session_state.get(c)] for c in columns}
        data['insurance'] = [1 if i == 'Yes' else 0 for i in data['insurance']]

        # Make a DataFrame
        df = pd.DataFrame(data).rename(
            columns={'manual_patient_id': 'id'})
        columns_int = ['prg', 'pl', 'pr', 'sk', 'ts', 'age']
        columns_float = ['m11', 'bd2']

        df[columns_int] = df[columns_int].astype(int)
        df[columns_float] = df[columns_float].astype(float)
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)
        return df

    try:
        # JSON body shaped as {column: [values, ...]}
        data = df.to_dict(orient='list')

        async with httpx.AsyncClient() as client:
            response = await client.post(model_endpoint, json=data, timeout=30)
            # Error statuses raise here; anything else that is not the expected
            # JSON fails the parsing below and is reported the same way.
            response.raise_for_status()

        pred_prob = response.json()['result']
        prediction = pred_prob['prediction'][0]
        probability = pred_prob['probability'][0]

        # Store results in session state
        st.session_state['prediction'] = prediction
        st.session_state['probability'] = probability
        df['prediction'] = prediction
        df['probability (%)'] = probability
        df['time_of_prediction'] = pd.Timestamp(dt.datetime.now())
        df['model_used'] = st.session_state['selected_model']
    except Exception as e:
        st.error(f'😞 Unable to connect to the API server. {e}')
        return df

    # Logging is local I/O, so its failures must not be reported as API errors.
    try:
        df.to_csv(HISTORY_FILE, mode='a',
                  header=not os.path.isfile(HISTORY_FILE))
    except OSError as e:
        st.warning(f'😞 The prediction could not be saved to the history file. {e}')

    return df


async def convert_string(df: pd.DataFrame, string: str) -> str:
    """Return *string* upper-cased when every column label of *df* is upper case.

    This lets new columns follow the naming style of the frame they are added
    to; for example ``'sepsis'`` becomes ``'SEPSIS'`` in an all-caps upload.
    Non-string labels (such as integer headers from a spreadsheet) count as
    "not upper case" instead of raising AttributeError. A frame with no
    columns yields the upper-cased string.

    Args:
        df: Frame whose column labels decide the casing.
        string: Column name to adapt.

    Returns:
        ``string.upper()`` or *string* unchanged.
    """
    all_upper = all(isinstance(col, str) and col.isupper() for col in df.columns)
    return string.upper() if all_upper else string


async def make_predictions(model_endpoint, df_uploaded=None, df_uploaded_clean=None) -> Optional[pd.DataFrame]:
    """Predict sepsis for several patients at once.

    Rows come from one of two sources, chosen via session state:

    * search form (``search_patient`` on and ``patient_id_bulk`` selected):
      the matching rows of the cleaned test dataset;
    * upload widget (search off and ``upload_bulk_predict`` clicked): the
      cleaned upload *df_uploaded_clean*.

    The rows are POSTed to *model_endpoint* as ``{column: [values]}`` JSON.
    The returned predictions and probabilities are added as columns, together
    with the prediction time and model name. They go onto the selected test
    rows, or onto the raw upload *df_uploaded*. The annotated frame is stored
    in ``st.session_state['bulk_prediction_df']``. Only searched (known)
    patients are appended to HISTORY_FILE. A failure to write that file is
    shown as a warning and does not discard the predictions.

    Args:
        model_endpoint: URL of the prediction API for the selected model.
        df_uploaded: Raw uploaded data; receives the result columns.
        df_uploaded_clean: Cleaned upload that is actually sent to the API.

    Returns:
        The annotated DataFrame on success, the unannotated rows if the
        request failed, or None when no valid input was provided.
    """
    df: Optional[pd.DataFrame] = None
    search_patient = st.session_state.get('search_patient', False)
    patient_id_bulk = st.session_state.get('patient_id_bulk', False)
    upload_bulk_predict = st.session_state.get('upload_bulk_predict', False)
    has_upload = df_uploaded is not None and df_uploaded_clean is not None
    if search_patient and patient_id_bulk:  # Search Form df and a patient was selected
        _, df_test = await get_test_data()
        mask = df_test['id'].isin(patient_id_bulk)
        df = df_test[mask].copy()  # copy so the cached test data is never mutated
    elif not (search_patient or patient_id_bulk) and upload_bulk_predict and has_upload:  # Upload widget df
        df = df_uploaded_clean.copy()
    else:  # Form did not send a patient
        message = 'You must choose valid patient(s) from the select box.'
        icon = '😞'
        st.toast(message, icon=icon)
        st.warning(message, icon=icon)
        return df

    try:
        # JSON body shaped as {column: [values, ...]}
        data = df.to_dict(orient='list')

        async with httpx.AsyncClient() as client:
            response = await client.post(model_endpoint, json=data, timeout=30)
            # Error statuses raise here; an unexpected body fails the parsing below.
            response.raise_for_status()

        pred_prob = response.json()['result']
        predictions = pred_prob['prediction']
        probabilities = pred_prob['probability']

        async def add_columns(target: pd.DataFrame) -> pd.DataFrame:
            # Column names follow the target's casing (see convert_string).
            target[await convert_string(target, 'sepsis')] = predictions
            target[await convert_string(target, 'probability_(%)')] = probabilities
            target[await convert_string(target, 'time_of_prediction')
                   ] = pd.Timestamp(dt.datetime.now())
            target[await convert_string(target, 'model_used')
                   ] = st.session_state['selected_model']
            return target

        # Searched rows are annotated directly; uploads annotate the raw file (no cleaning).
        df = await add_columns(df if search_patient else df_uploaded)

        # Store df with prediction results in session state
        st.session_state['bulk_prediction_df'] = df
    except Exception as e:
        st.error(f'😞 Unable to connect to the API server. {e}')
        return df

    if search_patient:  # Save only known patients
        # Logging is local I/O, so its failures must not be reported as API errors.
        try:
            df.to_csv(HISTORY_FILE, mode='a',
                      header=not os.path.isfile(HISTORY_FILE))
        except OSError as e:
            st.warning(f'😞 The predictions could not be saved to the history file. {e}')

    return df


def on_click(func: Callable, model_endpoint: str):
    """Run the async function *func(model_endpoint)* from a synchronous widget callback.

    The forms pass this plain function as ``on_click``, so the coroutine has
    to be driven to completion here. ``asyncio.run`` creates a fresh event
    loop and closes it even when *func* raises. The previous hand-rolled
    version leaked the loop on error and left a closed loop installed as the
    thread's current loop.

    Args:
        func: Coroutine function taking the endpoint URL (e.g. make_prediction).
        model_endpoint: URL forwarded to *func*.
    """
    asyncio.run(func(model_endpoint))


async def search_patient_form(model_endpoint: str) -> None:
    """Render the patient-search form for the active prediction tab.

    The options are every patient id in the cleaned test data plus an empty
    entry.

    * Single-prediction tab (``st.session_state['sidebar']`` equal to
      ``'single_prediction'``): a selectbox keyed ``'search_patient_id'``,
      with the empty entry preselected. It submits to make_prediction.
    * Any other tab: a multiselect keyed ``'patient_id_bulk'``. It submits
      to make_predictions.

    Both forms submit through on_click.
    """
    _, df_test = await get_test_data()
    options = [*df_test['id'].unique().tolist(), '']

    single = st.session_state['sidebar'] == 'single_prediction'
    form_key = 'search_patient_id_form' if single else 'search_patient_id_bulk_form'
    callback = make_prediction if single else make_predictions

    with st.form(form_key):
        left, _ = st.columns(2)
        with left:
            st.write('#### Patient ID 🤒')
            if single:
                st.selectbox(
                    'Search a patient', options=options, index=len(options) - 1, key='search_patient_id')
            else:
                st.multiselect(
                    'Search a patient', options=options, default=None, key='patient_id_bulk')
        st.form_submit_button('Predict', type='primary', on_click=on_click,
                              kwargs={'func': callback, 'model_endpoint': model_endpoint})


async def gen_random_patient_id() -> str:
    """Return a random placeholder patient ID such as ``'ICU123456-gen-abcd'``.

    The ID is six random digits, then ``'-gen-'``, then four random lowercase
    letters, all drawn with ``random.choices``. The module is not
    cryptographically secure, which is fine for a form default.
    """
    digit_part = "".join(random.choices(string.digits, k=6))
    letter_part = "".join(random.choices(string.ascii_lowercase, k=4))
    return "ICU" + digit_part + "-gen-" + letter_part


async def manual_patient_form(model_endpoint) -> None:
    """Render the manual patient-entry form.

    The widgets write to the session-state keys manual_patient_id, age,
    insurance, m11, pr, prg, pl, sk, ts and bd2. Submitting calls
    make_prediction through on_click. The ID field defaults to a freshly
    generated random ID on every run. make_prediction later casts prg, pl,
    pr, sk, ts and age to int, and maps insurance 'Yes'/'No' to 1/0.
    """
    with st.form('manual_patient_form'):

        col1, col2, col3 = st.columns(3)

        with col1:
            st.write('### Patient Demographics 🛌')
            st.text_input(
                'ID', value=await gen_random_patient_id(), key='manual_patient_id')
            st.number_input('Age: patients age (years)', min_value=0,
                            max_value=100, step=1, key='age')
            st.selectbox('Insurance: If a patient holds a valid insurance card', options=[
                'Yes', 'No'], key='insurance')

        # NOTE(review): every measurement below enforces min_value=10.0; confirm that
        # smaller readings are genuinely invalid for each field.
        with col2:
            st.write('### Vital Signs 🩺')
            st.number_input('BMI (weight in kg/(height in m)^2)', min_value=10.0,
                            format="%.2f", step=1.00, key='m11')
            st.number_input(
                'Blood Pressure (mm Hg)', min_value=10.0, format="%.2f", step=1.00, key='pr')
            st.number_input(
                'PRG (plasma glucose)', min_value=10.0, format="%.2f", step=1.00, key='prg')

        with col3:
            st.write('### Blood Work 💉')
            st.number_input(
                'PL: Blood Work Result-1 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='pl')
            st.number_input(
                'SK: Blood Work Result-2 (mm)', min_value=10.0, format="%.2f", step=1.00, key='sk')
            st.number_input(
                'TS: Blood Work Result-3 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='ts')
            st.number_input(
                'BD2: Blood Work Result-4 (mu U/ml)', min_value=10.0, format="%.2f", step=1.00, key='bd2')

        st.form_submit_button('Predict', type='primary', on_click=on_click, kwargs=dict(
            func=make_prediction, model_endpoint=model_endpoint))


async def do_single_prediction(model_endpoint: str) -> None:
    """Show the single-patient input form.

    When the sidebar's ``search_patient`` toggle is on, this is the search
    form; otherwise it is the manual-entry form.
    """
    searching = st.session_state.get('search_patient', False)
    render_form = search_patient_form if searching else manual_patient_form
    await render_form(model_endpoint)


async def show_prediction() -> None:
    """Render the latest single-patient prediction, then clear it.

    Reads ``prediction`` and ``probability`` from session state; make_prediction
    writes them. With no prediction, a placeholder heading is shown. Otherwise
    a warning ('positive') or a success message (any other label) is rendered
    and repeated as a toast after a short pause. Both keys are reset to None
    afterwards, so each result is shown only once.
    """
    final_prediction = st.session_state.get('prediction', None)
    final_probability = st.session_state.get('probability', None)

    if final_prediction is None:
        st.markdown('#### Prediction will show below! 🔬')
        st.divider()
    else:
        st.markdown('#### Prediction! 🔬')
        st.divider()
        if final_prediction.lower() == 'positive':
            st.toast("Sepsis alert!", icon='🦠')
            message = f"It is **{final_probability:.2f} %** likely that the patient will develop **sepsis.**"
            st.warning(message, icon='😞')
            # Non-blocking pause inside a coroutine before repeating the message.
            await asyncio.sleep(5)
            st.toast(message)
        else:
            st.toast("Continuous monitoring", icon='🔬')
            message = f"The patient will **not** develop sepsis with a likelihood of **{final_probability:.2f}%**."
            st.success(message, icon='😊')
            await asyncio.sleep(1)
            st.toast(message)

    # Set prediction and probability to None so the result is not shown again
    st.session_state['prediction'] = None
    st.session_state['probability'] = None


async def convert_df(df: pd.DataFrame):
    """Serialize *df* to CSV text, without the index, for ``st.download_button``.

    Left uncached: ``st.cache_data`` was dropped here because caching results
    of async functions proved buggy.
    """
    csv_text = df.to_csv(index=False)
    return csv_text


def _read_uploaded_file(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded file as CSV, falling back to Excel if CSV parsing fails."""
    try:
        return pd.read_csv(uploaded_file)
    except Exception:
        # The failed CSV attempt may have advanced the stream; rewind before retrying.
        uploaded_file.seek(0)
        return pd.read_excel(uploaded_file)


def _matches_template(df: pd.DataFrame, template: pd.DataFrame) -> bool:
    """Return True when *df* has exactly the template's columns with the same dtypes.

    Dtypes are compared column by column by name, so column order does not matter.
    """
    if set(df.columns) != set(template.columns):
        return False
    return df[list(template.columns)].dtypes.equals(template.dtypes)


async def _show_upload_help(df_test_raw: pd.DataFrame) -> None:
    """Tell the user the upload was rejected and show/offer the expected template."""
    st.subheader('Data template')
    data_template = df_test_raw[:3]
    st.dataframe(data_template)
    csv = await convert_df(data_template)
    message_1 = 'Upload a valid csv or excel file.'
    message_2 = f"{message_1.split('.')[0]} with the columns and schema of the above data template."
    icon = '😞'
    st.toast(message_1, icon=icon)

    st.download_button(
        label='Download template',
        data=csv,
        file_name='Data template.csv',
        mime="text/csv",
        type='secondary',
        key='download-data-template'
    )
    st.info('Download the above template for use as a baseline structure.')

    # Display expander to show the data dictionary
    with st.expander("Expand to see the data dictionary", icon="💡"):
        st.subheader("Data dictionary")
        st.markdown(markdown_table_all)
    st.warning(message_2, icon=icon)


async def bulk_upload_widget(model_endpoint: str) -> Optional[pd.DataFrame]:
    """Render the file uploader and run bulk predictions on a valid upload.

    The Predict button (key ``'upload_bulk_predict'``) is disabled until a
    file is chosen. On click, the file is parsed as CSV first, with Excel as a
    fallback. It must have exactly the columns and dtypes of the raw test
    data, in any column order. A valid file is reordered to the template's
    column order, cleaned with Janitor, and passed to make_predictions. If the
    file cannot be read, does not match, or fails cleaning, a template, its
    download button and the data dictionary are shown instead.

    Args:
        model_endpoint: URL of the prediction API for the selected model.

    Returns:
        The DataFrame returned by make_predictions, or None when nothing was
        submitted or the upload was rejected.
    """
    uploaded_file = st.file_uploader(
        "Choose a CSV or Excel File", type=['csv', 'xls', 'xlsx'])

    uploaded = uploaded_file is not None

    upload_bulk_predict = st.button('Predict', type='primary',
                                    help='Upload a csv/excel file to make predictions', disabled=not uploaded, key='upload_bulk_predict')
    if not (upload_bulk_predict and uploaded):
        return None

    df_test_raw, _ = await get_test_data()
    df: Optional[pd.DataFrame] = None
    df_clean: Optional[pd.DataFrame] = None
    try:
        candidate = _read_uploaded_file(uploaded_file)
        if _matches_template(candidate, df_test_raw):
            # Follow the template's column order so the API payload matches it.
            candidate = candidate.reindex(columns=df_test_raw.columns)
            # Clean a copy (as get_test_data does), in case clean_dataframe
            # mutates its input; the raw upload is what receives the results.
            df_clean = Janitor().clean_dataframe(candidate.copy())
            df = candidate
    except Exception:
        # Unreadable file or cleaning failure: treat as an invalid upload.
        df = None

    if df is None:
        await _show_upload_help(df_test_raw)
        return None

    return await make_predictions(
        model_endpoint, df_uploaded=df, df_uploaded_clean=df_clean)


async def do_bulk_prediction(model_endpoint: str) -> None:
    """Show the bulk-prediction input.

    When the ``search_patient`` toggle is on, this is the multi-patient search
    form; otherwise it is the file uploader.

    NOTE(review): the DataFrame returned by bulk_upload_widget is discarded,
    so this always returns None. Results reach the page through
    ``st.session_state['bulk_prediction_df']``, which make_predictions sets.
    """
    if not st.session_state.get('search_patient', False):
        # File uploader
        await bulk_upload_widget(model_endpoint)
        return
    await search_patient_form(model_endpoint)


async def show_bulk_predictions(df: pd.DataFrame) -> None:
    """Display bulk results with a CSV download button, then clear them.

    Nothing is rendered when *df* is None. Otherwise the frame is shown with
    every value as a string, the user is notified, a download button for the
    CSV is offered, and ``st.session_state['bulk_prediction_df']`` is reset to
    None so the results are shown only once.
    """
    if df is None:
        return

    st.subheader("Bulk predictions 🔮", divider=True)
    st.dataframe(df.astype(str))

    download_payload = await convert_df(df)
    ready_message = 'The predictions are ready for download.'
    ready_icon = '⬇️'
    st.toast(ready_message, icon=ready_icon)
    st.info(ready_message, icon=ready_icon)
    st.download_button(
        label='Download predictions',
        data=download_payload,
        file_name='Bulk prediction.csv',
        mime="text/csv",
        type='secondary',
        key='download-bulk-prediction'
    )

    st.session_state['bulk_prediction_df'] = None


async def sidebar(sidebar_type: str) -> None:
    """Record which prediction tab is active under the session-state key ``'sidebar'``.

    search_patient_form reads this key to choose between the single and bulk
    search forms. The old ``-> st.sidebar`` annotation was wrong: the function
    has always returned None.

    Args:
        sidebar_type: ``'single_prediction'`` or ``'bulk_prediction'``.
    """
    st.session_state['sidebar'] = sidebar_type


async def main():
    """Render the sepsis prediction page.

    The page is drawn in this order:

    1. Title and navigation.
    2. The sidebar "Looking for a patient?" toggle (key ``'search_patient'``).
    3. The model picker.
    4. A tab bar switching between single-patient prediction (tab 1) and
       bulk prediction (tab 2).
    5. The footer.
    """
    st.title("🤖 Predict Sepsis 🦠")

    # Navigation
    await navigation()

    st.sidebar.toggle("Looking for a patient?", value=st.session_state.get(
        'search_patient', False), key='search_patient')

    selected_model = await select_model()
    model_endpoint = await endpoint(selected_model)

    selected_predict_tab = st.session_state.get('selected_predict_tab')
    default = 1 if selected_predict_tab is None else selected_predict_tab

    with st.spinner('A little house keeping...'):
        # Waits 1.5 s while no 'sleep' key exists; the key is then set to 0, so
        # later reruns don't wait. asyncio.sleep avoids blocking inside a coroutine.
        await asyncio.sleep(st.session_state.get('sleep', 1.5))
        chosen_id = stx.tab_bar(data=[
            stx.TabBarItemData(id=1, title='🔬 Predict', description=''),
            stx.TabBarItemData(id=2, title='🔮 Bulk predict',
                               description=''),
        ], default=default)
        st.session_state['sleep'] = 0

    # NOTE(review): ids are compared as strings, so tab_bar presumably returns the
    # chosen id as str — confirm against extra_streamlit_components.
    if chosen_id == '1':
        await sidebar('single_prediction')
        await do_single_prediction(model_endpoint)
        await show_prediction()

    elif chosen_id == '2':
        await sidebar('bulk_prediction')
        df_with_predictions = await do_bulk_prediction(model_endpoint)
        if df_with_predictions is None:
            # Fall back to the frame make_predictions stored in session state.
            df_with_predictions = st.session_state.get(
                'bulk_prediction_df', None)
        await show_bulk_predictions(df_with_predictions)

    # Add footer
    await footer()


if __name__ == "__main__":
    asyncio.run(main())