import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from forecast import get_forecast_data
from retrieve_coefs_max_yield import get_coefs_Kc_Ky_and_max_yield
from utils.soil_utils import get_soil_properties


def calculate_ETx(Kc, ETo):
    """
    Return the maximum crop evapotranspiration ETx = Kc * ETo.

    Parameters:
    Kc: Crop coefficient (scalar or array-like, e.g. a pandas Series).
    ETo: Reference evapotranspiration, in the same shape/alignment as Kc
        when both are array-like.

    Returns:
    The element-wise product of Kc and ETo, in the units of ETo.
    """
    return Kc * ETo


def calculate_ETa(ETx, soil_moisture, field_capacity, wilting_point, water_deficit, ETo):
    """
    Estimate the actual evapotranspiration (ETa) from the maximum one (ETx).

    ETx is scaled by a water-stress coefficient Ks = 1 - water_deficit / ETo,
    limited to the range [0, 1]. Two soil-moisture overrides are then applied,
    in this order:
      * where soil_moisture > field_capacity, ETa is set to ETx (no stress);
      * where soil_moisture < wilting_point, ETa is set to 0.

    Parameters:
    ETx (Series): Maximum evapotranspiration.
    soil_moisture (Series): Soil moisture; must be expressed in the same unit
        as field_capacity and wilting_point since they are compared directly.
    field_capacity (float): Soil moisture threshold above which ETa = ETx.
    wilting_point (float): Soil moisture threshold below which ETa = 0.
    water_deficit (Series): Water deficit, in the same unit as ETo.
    ETo (Series): Reference evapotranspiration.

    Returns:
    Series: Actual evapotranspiration, a new Series (inputs are not modified).
    """
    # Water-stress coefficient, bounded so it can neither boost nor invert ETx.
    stress_coef = (1 - (water_deficit / ETo)).clip(lower=0, upper=1)
    actual_et = ETx * stress_coef

    # Soil wetter than field capacity: the crop transpires at its maximum rate.
    saturated = soil_moisture > field_capacity
    actual_et.loc[saturated] = ETx.loc[saturated]

    # Soil drier than the wilting point: no evapotranspiration at all.
    depleted = soil_moisture < wilting_point
    actual_et.loc[depleted] = 0

    return actual_et


def calculate_yield_projection(Yx, ETx, ETa, Ky):
    """
    Project the yield with the FAO water production function.

    Computes Ya = Yx * (1 - Ky * (1 - ETa / ETx)), forces the result to 0
    wherever ETx is 0 (the ratio ETa / ETx is undefined there), and rounds
    to two decimals.

    Parameters:
    Yx (float): Maximum yield (quintal/ha)
    ETx (Series): Maximum evapotranspiration
    ETa (Series): Actual evapotranspiration
    Ky (Series or float): Yield response factor

    Returns:
    Series: Projected yield (quintal/ha), rounded to 2 decimals
    """
    relative_et_deficit = 1 - ETa / ETx
    projected = Yx * (1 - Ky * relative_et_deficit)

    # Division by a zero ETx yields NaN/inf; report no yield for those rows.
    no_demand = ETx == 0
    projected.loc[no_demand] = 0

    return projected.round(2)


def add_cultural_coefs(monthly_forecast: pd.DataFrame, cultural_coefs: pd.DataFrame) -> pd.DataFrame:
    """
    Attach the monthly crop coefficients Kc and Ky to each forecast row.

    For every calendar month 1..12, the first row of ``cultural_coefs`` whose
    ``Mois`` equals that month provides the ``Kc`` and ``Ky`` values written to
    all rows of ``monthly_forecast`` whose ``month`` equals that month. Rows
    whose ``month`` is outside 1..12 keep the initial value 0.0.

    Parameters:
    monthly_forecast (DataFrame): Forecast table with a ``month`` column.
        It is modified in place (columns ``Kc`` and ``Ky`` are created or
        overwritten) and also returned.
    cultural_coefs (DataFrame): Table with columns ``Mois``, ``Kc`` and ``Ky``.

    Returns:
    DataFrame: The same ``monthly_forecast`` object, with ``Kc`` and ``Ky`` set.

    Raises:
    IndexError: If ``cultural_coefs`` has no row for one of the months 1..12.
    """
    # Initialise as float: the coefficients are fractional, and writing floats
    # into an integer column through .loc is a deprecated implicit upcast in
    # recent pandas (FutureWarning, announced to become an error).
    monthly_forecast["Kc"] = 0.0
    monthly_forecast["Ky"] = 0.0

    forecast_months = monthly_forecast["month"].to_numpy()
    for month in range(1, 13):
        coef_rows = cultural_coefs["Mois"] == month
        Kc = cultural_coefs["Kc"][coef_rows].iloc[0]
        Ky = cultural_coefs["Ky"][coef_rows].iloc[0]

        # Positional boolean mask, computed once and reused for both columns.
        in_month = forecast_months == month
        monthly_forecast.loc[in_month, "Kc"] = Kc
        monthly_forecast.loc[in_month, "Ky"] = Ky
    return monthly_forecast


def compute_yield_forecast(
        latitude: float,
        longitude: float,
        culture: str = "Colza d'hiver",
        region: str = "Bourgogne-Franche-Comté",
        scenario: str = "pessimist",
        shading_coef: float = 0.,
):
    """
    Build the forecast table for a location and add the projected yield.

    Steps:
      1. fetch the forecast for (latitude, longitude) under ``scenario`` and
         ``shading_coef``;
      2. look up the crop coefficients and the maximum yield for
         (``culture``, ``region``) and attach Kc / Ky to the forecast rows;
      3. fetch the soil properties of the location;
      4. derive ETx, ETa and the projected yield.

    Returns:
    DataFrame: The forecast table with the added columns ``Kc``, ``Ky`` and
    ``Estimated yield (quintal/ha)``.
    """
    forecast = get_forecast_data(latitude, longitude, scenario=scenario, shading_coef=shading_coef)

    coefs_by_month, yield_ceiling = get_coefs_Kc_Ky_and_max_yield(culture, region)
    forecast = add_cultural_coefs(forecast, coefs_by_month)

    soil = get_soil_properties(latitude, longitude)

    # The forecast's evaporation column is used as the reference ETo.
    reference_et = forecast["Evaporation (mm/day)"]
    max_et = calculate_ETx(forecast["Kc"], reference_et)

    # NOTE(review): soil moisture is a kg m-2 column compared against the soil
    # properties' field capacity / wilting point; confirm the units agree.
    actual_et = calculate_ETa(
        max_et,
        forecast["Moisture in Upper Portion of Soil Column (kg m-2)"],
        soil["field_capacity"],
        soil["wilting_point"],
        water_deficit=forecast["Water Deficit (mm/day)"],
        ETo=reference_et,
    )

    forecast["Estimated yield (quintal/ha)"] = calculate_yield_projection(
        Yx=yield_ceiling,
        ETx=max_et,
        ETa=actual_et,
        Ky=forecast["Ky"],
    )

    return forecast


def get_annual_yield(monthly_forecast: pd.DataFrame) -> pd.Series:
    yield_forecast = pd.Series(
        index=monthly_forecast["time"],
        data=monthly_forecast["Estimated yield (quintal/ha)"].to_numpy(),
    )
    yield_forecast = yield_forecast.resample("1YE").mean()
    return yield_forecast


def plot_yield(
        latitude: float,
        longitude: float,
        culture: str = "Colza d'hiver",
        region: str = "Bourgogne-Franche-Comté",
        scenario: str = "pessimist",
        shading_coef: float = 0.,
) -> plt.Figure:
    """
    Plot the aggregated yield forecast with and without shading.

    Two forecasts are computed for the same location, culture, region and
    scenario: one with no shading (``shading_coef=0``) and one with the
    requested ``shading_coef``. Annual mean yields are summed over a rolling
    10-year window, and the sums at years divisible by 10 are drawn as two
    side-by-side bar series.

    Parameters:
    latitude (float): Latitude of the location.
    longitude (float): Longitude of the location.
    culture (str): Crop name used to look up the coefficients.
    region (str): Region name used to look up the coefficients.
    scenario (str): Forecast scenario name.
    shading_coef (float): Shading coefficient of the second bar series, as a
        fraction (0.2 is labelled "20% shading").

    Returns:
    Figure: The matplotlib figure containing the bar chart.
    """
    # Everything except the shading is shared by both runs; `region` must be
    # forwarded too, otherwise the default region would silently be used.
    shared_args = dict(
        latitude=latitude,
        longitude=longitude,
        culture=culture,
        region=region,
        scenario=scenario,
    )
    monthly_forecast = compute_yield_forecast(shading_coef=0., **shared_args)
    monthly_forecast_with_shading = compute_yield_forecast(shading_coef=shading_coef, **shared_args)

    yield_forecast = get_annual_yield(monthly_forecast)
    yield_forecast_with_shading = get_annual_yield(monthly_forecast_with_shading)

    n_years = 10
    # NOTE(review): assumes the first annual value corresponds to 2025 - confirm
    # against the forecast data's actual start year.
    years = 2025 + np.arange(len(yield_forecast_with_shading))
    decade_mask = years % n_years == 0
    # Rolling sums are NaN until a full n_years window is available.
    aggregated_forecasts = yield_forecast.rolling(n_years).sum()[decade_mask]
    aggregated_forecasts_with_shading = yield_forecast_with_shading.rolling(n_years).sum()[decade_mask]

    width = 3  # the width of the bars
    fig, ax = plt.subplots(layout='constrained')
    aggregated_years = years[decade_mask]
    ax.bar(aggregated_years, aggregated_forecasts, width, label="No shading")
    # Shift the second series by one bar width so the bars sit side by side;
    # the label reflects the shading actually requested.
    ax.bar(
        aggregated_years + width,
        aggregated_forecasts_with_shading,
        width,
        label=f"{shading_coef:.0%} shading",
    )
    ax.legend()
    ax.set_ylim(150)  # a single positional argument sets only the lower limit

    return fig

if __name__ == '__main__':
    # Demo run: compare the no-shading forecast with a 20% shading forecast
    # for winter rapeseed under the pessimistic scenario, then display the
    # bar chart. The plotting logic lives in plot_yield so it is not
    # duplicated here.
    plot_yield(
        latitude=47,
        longitude=5,
        culture="Colza d'hiver",
        scenario="pessimist",
        shading_coef=0.2,
    )
    plt.show()