import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st


def load_data(year):
    """Read ``validation/<year>.csv`` into a DataFrame.

    Args:
        year: Year used to build the CSV file name.

    Returns:
        The loaded DataFrame, or ``None`` when the file does not exist. In the
        missing-file case an error message is shown in the Streamlit app.
    """
    csv_path = f"validation/{year}.csv"
    try:
        frame = pd.read_csv(csv_path)
    except FileNotFoundError:
        st.error(f"No data found for year {year}. Please ensure the file exists.")
        frame = None
    return frame

def filter_data(data, country, brand):
    """Return the rows of *data* matching both *country* and *brand*.

    Args:
        data: DataFrame with ``country`` and ``brand`` columns.
        country: Value the ``country`` column must equal.
        brand: Value the ``brand`` column must equal.

    Returns:
        The subset of rows (original index preserved) where both match.
    """
    country_match = data['country'] == country
    brand_match = data['brand'] == brand
    return data[country_match & brand_match]

def plot_data(filtered_data):
    """Plot target and prediction against date, with a 10th-90th percentile band.

    Args:
        filtered_data: DataFrame with the columns ``date``, ``target``,
            ``prediction``, ``prediction_10`` and ``prediction_90``.

    Shows a Streamlit warning and returns early when *filtered_data* is empty;
    otherwise renders the chart into the Streamlit app.
    """
    if filtered_data.empty:
        st.warning("No data available for the selected criteria.")
        return

    st.write("Plotting target vs date with confidence intervals.")

    # Line plots and fill_between connect points in row order, so sort by the
    # parsed date; unsorted input rows would otherwise draw zig-zag lines and
    # a self-intersecting interval band.
    plot_df = filtered_data.assign(
        date=pd.to_datetime(filtered_data['date'])
    ).sort_values('date')
    dates = plot_df['date']

    # Use an explicit figure rather than pyplot's global state: passing the
    # pyplot module to st.pyplot is deprecated, and the figure must be closed
    # so figures don't accumulate across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        ax.plot(dates, plot_df['target'], label='Target', color='blue')
        ax.plot(dates, plot_df['prediction'], label='Prediction', color='orange')
        ax.fill_between(
            dates,
            plot_df['prediction_10'],
            plot_df['prediction_90'],
            color='orange',
            alpha=0.2,
            label='Confidence Interval (10th to 90th percentile)',
        )

        ax.set_xlabel('Date')
        ax.set_ylabel('Target')
        ax.set_title('Target vs Date with Confidence Interval')
        ax.legend()
        ax.grid(True)
        # Rotate date tick labels so they don't overlap on a narrow figure.
        fig.autofmt_xdate()
        st.pyplot(fig)
    finally:
        plt.close(fig)

def _default_first(options, default):
    """Return *options* as a list, moving *default* to the front only if present."""
    options = list(options)
    if default not in options:
        # Don't invent an option the data doesn't contain: selecting it would
        # just produce an empty plot.
        return options
    return [default] + [opt for opt in options if opt != default]


def main():
    """Render the app: choose year, country and brand in the sidebar, then plot."""
    st.title("Data Visualization App")

    # Step 1: Select Year, default to 2021
    years = list(range(2017, 2022))
    year = st.sidebar.selectbox("Select Year", years, index=years.index(2021))

    # Load data based on year selection
    data = load_data(year)
    if data is None:
        # load_data has already shown an error message for the missing file.
        return

    # Step 2: Select Country based on available options for the year,
    # preferring COUNTRY_6B71 when the year contains it.
    countries = _default_first(data['country'].unique(), "COUNTRY_6B71")
    country = st.sidebar.selectbox("Select Country", countries)

    # Step 3: Select Brand based on available options for the year and
    # country, preferring BRAND_24CB when this country carries it.
    country_brands = data.loc[data['country'] == country, 'brand'].unique()
    brands = _default_first(country_brands, "BRAND_24CB")
    brand = st.sidebar.selectbox("Select Brand", brands)

    # Filter data based on inputs, then plot it.
    plot_data(filter_data(data, country, brand))

if __name__ == "__main__":
    # Run the app only when this file is executed as a script, not on import.
    main()