# stress hydrique and rendement, besoin en eau
# (water stress and yield, water requirements)
import os
from typing import List, Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from data_pipelines.historical_weather_data import (
    download_historical_weather_data,
    aggregate_hourly_weather_data,
)
from forecast import get_forecast_data
from compute_et0_adjusted import compute_et0


def water_deficit(
    df, latitude, longitude, shading_coef=0, historic=True, m_term=4.5
):
    """Compute daily ET0, precipitation and water deficit for a weather table.

    ``df`` is copied; the returned DataFrame contains all original columns plus
    the standardized inputs consumed by ``compute_et0`` (``irradiance``,
    ``air_temperature_min/max``, ``relative_humidity_min/max``, ``wind_speed``,
    ``month``, ``day_of_year``) and three derived columns:

    * ``"Evaporation (mm/day)"`` -- ET0 returned by ``compute_et0``, clipped at 0.
    * ``"Precipitation (mm/day)"`` -- ``"Precipitation (kg m-2 s-1)"`` x 86400.
    * ``"Water Deficit (mm/day)"`` -- ET0 - P + ``m_term``.

    Args:
        df: Weather DataFrame with a ``"time"`` column and the radiation,
            temperature, humidity, wind and precipitation columns read below.
        latitude: Passed through to ``compute_et0``.
        longitude: Passed through to ``compute_et0``.
        shading_coef: Fraction of shortwave radiation removed before computing
            ET0 (irradiance is multiplied by ``1 - shading_coef``).
        historic: If truthy, read the separate ``"Relative Humidity_min"`` /
            ``"Relative Humidity_max"`` columns; otherwise use the single
            ``"Relative Humidity (%)"`` column for both bounds.
        m_term: Constant "M" term of ``Water Deficit = ET0 - P + M`` in mm/day.
            The default 4.5 keeps the previously hard-coded value; its physical
            meaning is not documented here -- confirm with the model owner.

    Returns:
        A new DataFrame with the columns described above.
    """
    preprocessed_data = df.copy()

    # Standardized inputs expected by compute_et0. Shading removes a fraction
    # of the incoming shortwave radiation.
    preprocessed_data["irradiance"] = preprocessed_data[
        "Surface Downwelling Shortwave Radiation (W/m²)"
    ] * (1 - shading_coef)
    preprocessed_data["air_temperature_min"] = preprocessed_data[
        "Daily Minimum Near Surface Air Temperature (°C)"
    ]
    preprocessed_data["air_temperature_max"] = preprocessed_data[
        "Daily Maximum Near Surface Air Temperature (°C)"
    ]
    if historic:
        # Historical data provides separate daily min / max humidity columns.
        preprocessed_data["relative_humidity_min"] = preprocessed_data[
            "Relative Humidity_min"
        ]
        preprocessed_data["relative_humidity_max"] = preprocessed_data[
            "Relative Humidity_max"
        ]
    else:
        # Only a single humidity column is available: use it for both bounds.
        preprocessed_data["relative_humidity_min"] = preprocessed_data[
            "Relative Humidity (%)"
        ]
        preprocessed_data["relative_humidity_max"] = preprocessed_data[
            "Relative Humidity (%)"
        ]
    preprocessed_data["wind_speed"] = preprocessed_data["Near Surface Wind Speed (m/s)"]

    # Parse timestamps (unparseable values become NaT) and derive calendar
    # features: month and day of year (Julian day).
    preprocessed_data["time"] = pd.to_datetime(
        preprocessed_data["time"], errors="coerce"
    )
    preprocessed_data["month"] = preprocessed_data["time"].dt.month
    preprocessed_data["day_of_year"] = preprocessed_data["time"].dt.dayofyear

    # Reference evapotranspiration; negative values are clipped to zero.
    et0 = compute_et0(preprocessed_data, latitude, longitude)
    preprocessed_data["Evaporation (mm/day)"] = et0
    preprocessed_data["Evaporation (mm/day)"] = preprocessed_data[
        "Evaporation (mm/day)"
    ].clip(lower=0)

    # Convert precipitation flux from kg/m²/s to mm/day (86400 s per day).
    preprocessed_data["Precipitation (mm/day)"] = (
        86400 * preprocessed_data["Precipitation (kg m-2 s-1)"]
    )

    # Water Deficit = ET0 - P + M
    preprocessed_data["Water Deficit (mm/day)"] = (
        preprocessed_data["Evaporation (mm/day)"]
        - preprocessed_data["Precipitation (mm/day)"]
        + m_term
    )

    return preprocessed_data


def concatenate_historic_forecast(
    historic, forecast, cols_to_keep, value_period_col="forecast scénario modéré"
):
    """Stack historic rows above forecast rows, tagged by a ``period`` column.

    The inputs are NOT modified: each one is labelled on a copy (via
    ``DataFrame.assign``) before selecting ``cols_to_keep``. Previously the
    ``period`` column was written into the caller's DataFrames.

    Args:
        historic: Historical data; its rows get ``period == "historique"``.
        forecast: Forecast data; its rows get ``period == value_period_col``.
        cols_to_keep: Columns to retain, in output order. Include ``"period"``
            to keep the label. Every other name must exist in both inputs.
        value_period_col: Label used for the forecast rows.

    Returns:
        A new DataFrame: historic rows followed by forecast rows. The original
        indices are kept, so the index may contain duplicates.
    """
    labelled_historic = historic.assign(period="historique")[cols_to_keep]
    labelled_forecast = forecast.assign(period=value_period_col)[cols_to_keep]
    return pd.concat([labelled_historic, labelled_forecast])


def _add_mean_bars(fig, data, x_axis, column, forecast_color):
    """Add one bar trace per ``period``: the mean of ``column`` per ``x_axis`` value.

    Bars for ``"historique"`` are blue; any other period uses ``forecast_color``.
    """
    for period in data["period"].unique():
        segment = data[data["period"] == period]
        means = segment.groupby(x_axis)[column].mean().reset_index()
        fig.add_trace(
            go.Bar(
                x=means[x_axis],
                y=means[column],
                name=f"{period}",
                marker=dict(
                    color="blue" if period == "historique" else forecast_color
                ),
            )
        )


def _add_period_lines(
    fig, data, x_axis, column, forecast_color, forecast_group, showlegend
):
    """Add one line trace per ``period`` value of ``data``.

    ``"historique"`` is drawn solid blue in legend group ``"group1"``. Other
    periods are dotted ``forecast_color`` lines in ``forecast_group``.
    ``showlegend=None`` leaves Plotly's default (shown) untouched, and
    ``dash=None`` likewise leaves the line style unset.
    """
    for period in data["period"].unique():
        segment = data[data["period"] == period]
        is_history = period == "historique"
        fig.add_trace(
            go.Scatter(
                x=segment[x_axis],
                y=segment[column],
                mode="lines",
                name=f"{period}",
                legendgroup="group1" if is_history else forecast_group,
                showlegend=showlegend,
                line=dict(
                    color="blue" if is_history else forecast_color,
                    dash=None if is_history else "dot",
                ),
            )
        )


def _add_bridge_line(
    fig, data, x_axis, column, name, legendgroup, color, showlegend,
    start=2023, end=2025,
):
    """Draw a dotted line over rows with ``start < x_axis <= end``.

    This joins the last historical point to the first forecast point. It
    assumes ``x_axis`` holds numeric years and that the history/forecast
    boundary falls in that window -- TODO confirm when data ranges change.
    """
    window = data[(data[x_axis] > start) & (data[x_axis] <= end)]
    fig.add_trace(
        go.Scatter(
            x=window[x_axis],
            y=window[column].interpolate(),
            mode="lines",
            name=name,
            legendgroup=legendgroup,
            showlegend=showlegend,
            line=dict(color=color, dash="dot"),
        ),
    )


def visualize_climate(
    moderate: pd.DataFrame,
    historic: pd.DataFrame,
    pessimist: pd.DataFrame,
    x_axis="year",
    column: str = "Precipitation (mm)",
    cols_to_keep: Optional[List[str]] = None,
):
    """Plot ``column`` for the historical period and two forecast scenarios.

    Historic data is concatenated with each scenario via
    ``concatenate_historic_forecast`` and sorted by ``x_axis``.

    * ``column == "Precipitation (mm)"`` produces a grouped bar chart. It
      shows the mean per ``x_axis`` value: historic in blue, moderate in
      purple, pessimist in orange.
    * Any other column produces a line chart. History is solid blue and
      forecasts are dotted (purple / orange). Short dotted "bridge" segments
      over years 2024-2025 connect history to each forecast.

    Args:
        moderate: Moderate-scenario forecast data.
        historic: Historical data.
        pessimist: Pessimistic-scenario forecast data.
        x_axis: Column used for the x-axis (compared against years 2023/2025
            for the line-chart bridges).
        column: Column to plot.
        cols_to_keep: Columns retained when concatenating. It must include
            ``x_axis``, ``column`` and ``"period"``. Defaults to the
            precipitation / temperature / radiation / water-deficit set below.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    if cols_to_keep is None:
        cols_to_keep = [
            "Precipitation (mm)",
            "Near Surface Air Temperature (°C)",
            "Surface Downwelling Shortwave Radiation (W/m²)",
            "Water Deficit (mm/day)",
            "year",
            "period",
        ]

    concatenated_moderate = concatenate_historic_forecast(
        historic, moderate, cols_to_keep
    ).sort_values(by=x_axis)
    concatenated_pessimist = concatenate_historic_forecast(
        historic, pessimist, cols_to_keep, "forecast scénario pessimiste"
    ).sort_values(by=x_axis)

    fig = go.Figure()

    if column == "Precipitation (mm)":
        _add_mean_bars(fig, concatenated_moderate, x_axis, column, "purple")
        # History is already shown once through the moderate concatenation.
        pessimist_only = concatenated_pessimist[
            concatenated_pessimist["period"] != "historique"
        ]
        _add_mean_bars(fig, pessimist_only, x_axis, column, "orange")

        fig.update_layout(
            title=f"Moyenne de {column} par année",
            xaxis_title="Année",
            yaxis_title="Précipitation (mm)",
            barmode="group",  # Bars of different periods side by side.
        )
    else:
        # Moderate traces are hidden from the legend. Their legend entries
        # come from the pessimist history trace (group1) and from the moderate
        # bridge line (group2).
        _add_period_lines(
            fig, concatenated_moderate, x_axis, column,
            forecast_color="purple", forecast_group="group2", showlegend=False,
        )
        _add_period_lines(
            fig, concatenated_pessimist, x_axis, column,
            forecast_color="orange", forecast_group="group3", showlegend=None,
        )
        _add_bridge_line(
            fig, concatenated_pessimist, x_axis, column,
            name="forecast scénario pessimiste", legendgroup="group3",
            color="orange", showlegend=False,
        )
        _add_bridge_line(
            fig, concatenated_moderate, x_axis, column,
            name="forecast scénario modéré", legendgroup="group2",
            color="purple", showlegend=None,
        )
        # NOTE(review): x-axis title is English while the other labels are
        # French ("Année" in the bar chart) -- kept unchanged, confirm.
        fig.update_layout(
            title=f"Historique et Forecast pour {column}",
            xaxis_title="Year",
            yaxis_title=column,
        )

    return fig


def aggregate_yearly(df, col_to_agg, operation="mean"):
    """Overwrite ``col_to_agg`` with its per-``"year"`` aggregate on every row.

    ``df`` is modified in place and also returned, so callers may reassign.

    Args:
        df: DataFrame containing a ``"year"`` column and ``col_to_agg``.
        col_to_agg: Column whose values are replaced by the yearly aggregate.
        operation: Aggregation name accepted by ``GroupBy.transform``
            (e.g. ``"mean"``, ``"sum"``, ``"max"``).

    Returns:
        The same DataFrame object, with ``col_to_agg`` overwritten.
    """
    yearly_values = df[col_to_agg].groupby(df["year"]).transform(operation)
    df[col_to_agg] = yearly_values
    return df


def generate_plots(
    moderate: pd.DataFrame,
    historic: pd.DataFrame,
    pessimist: pd.DataFrame,
    x_axes: List[str],
    cols_to_plot: List[str],
):
    """Build one ``visualize_climate`` figure per requested column.

    ``x_axes[k]`` is the x-axis column for ``cols_to_plot[k]``. An
    ``IndexError`` is raised if ``x_axes`` is shorter than ``cols_to_plot``;
    extra entries in ``x_axes`` are ignored.

    Returns:
        A list of figures in the same order as ``cols_to_plot``.
    """
    return [
        visualize_climate(moderate, historic, pessimist, x_axes[idx], column)
        for idx, column in enumerate(cols_to_plot)
    ]


def _prepare_forecast(forecast: pd.DataFrame) -> pd.DataFrame:
    """Return a time-sorted copy of a forecast frame with yearly precipitation.

    The precipitation flux ``"Precipitation (kg m-2 s-1)"`` (1 kg/m² = 1 mm)
    is renamed to ``"Precipitation (mm)"``. It is multiplied by 31 536 000,
    the number of seconds in a 365-day year, giving mm/year. A ``"year"``
    column is added.
    """
    prepared = forecast.rename(
        columns={"Precipitation (kg m-2 s-1)": "Precipitation (mm)"}
    )
    prepared["time"] = pd.to_datetime(prepared["time"])
    prepared = prepared.sort_values("time")
    prepared["year"] = prepared["time"].dt.year
    prepared["Precipitation (mm)"] = prepared["Precipitation (mm)"] * 31536000
    return prepared


def _load_historic(latitude, longitude, start_year, end_year):
    """Download, aggregate and enrich historical weather data.

    Rows from ``end_year`` onwards are dropped. ``water_deficit`` is applied,
    and ``"year"`` plus a yearly ``"Precipitation (mm)"`` column are added.
    """
    hourly = download_historical_weather_data(
        latitude, longitude, start_year, end_year
    )
    historic = aggregate_hourly_weather_data(hourly).reset_index()
    historic = historic.rename(
        columns={
            "precipitation": "Precipitation (mm)",
            "air_temperature_mean": "Near Surface Air Temperature (°C)",
            "irradiance": "Surface Downwelling Shortwave Radiation (W/m²)",
            "index": "time",
        }
    )
    historic["time"] = pd.to_datetime(historic["time"])
    historic = historic.sort_values("time")
    # Exclude end_year itself (presumably incomplete -- confirm).
    historic = historic[historic["time"] < f"{end_year}-01-01"]

    # Rename to the column names water_deficit reads.
    historic = historic.rename(
        columns={
            "air_temperature_min": "Daily Minimum Near Surface Air Temperature (°C)",
            "air_temperature_max": "Daily Maximum Near Surface Air Temperature (°C)",
            "relative_humidity_min": "Relative Humidity_min",
            "relative_humidity_max": "Relative Humidity_max",
            "wind_speed": "Near Surface Wind Speed (m/s)",
            "Precipitation (mm)": "Precipitation (kg m-2 s-1)",
        }
    )
    # water_deficit multiplies this flux by 86400 s/day. Dividing by 3600
    # therefore turns an hourly amount into mm/day; the value is assumed to be
    # mean mm/hour -- confirm against aggregate_hourly_weather_data.
    historic["Precipitation (kg m-2 s-1)"] = historic["Precipitation (kg m-2 s-1)"] / 3600

    historic = water_deficit(historic, latitude, longitude)

    # Undo the /3600 above, then scale by 8760 h (365 days) to a yearly amount.
    historic = historic.rename(
        columns={"Precipitation (kg m-2 s-1)": "Precipitation (mm)"}
    )
    historic["Precipitation (mm)"] = historic["Precipitation (mm)"] * 3600
    historic["year"] = historic["time"].dt.year
    historic["Precipitation (mm)"] = historic["Precipitation (mm)"] * 8760.0
    return historic


def get_plots():
    """Build the historic-vs-forecast climate figures for the configured site.

    Steps:
        1. Download historical weather data for latitude 47, longitude 5,
           years 2000-2025, and keep rows before 2025.
        2. Fetch the "moderate" and "pessimist" forecasts.
        3. Replace each plotted column by its yearly aggregate (mean).

    Returns:
        A list of Plotly figures, one per entry of ``cols_to_plot``.
    """
    cols_to_plot = [
        "Precipitation (mm)",
        "Near Surface Air Temperature (°C)",
        "Surface Downwelling Shortwave Radiation (W/m²)",
        "Water Deficit (mm/day)",
    ]
    x_axes = ["year"] * len(cols_to_plot)
    latitude = 47
    longitude = 5
    start_year = 2000
    end_year = 2025

    historic = _load_historic(latitude, longitude, start_year, end_year)

    # NOTE(review): forecasts are not passed through water_deficit here.
    # "Water Deficit (mm/day)" must already be present in get_forecast_data's
    # output, otherwise aggregate_yearly raises KeyError -- confirm.
    moderate = _prepare_forecast(get_forecast_data(latitude, longitude, "moderate"))
    pessimist = _prepare_forecast(get_forecast_data(latitude, longitude, "pessimist"))
    pessimist["period"] = "forecast scénario pessimiste"

    for col in cols_to_plot:
        moderate = aggregate_yearly(moderate, col)
        historic = aggregate_yearly(historic, col)
        pessimist = aggregate_yearly(pessimist, col)
    return generate_plots(moderate, historic, pessimist, x_axes, cols_to_plot)