"""Streamlit app that plots target vs. prediction (with a 10th-90th percentile band).

Data is read from ``validation/<year>.csv``; the user picks a year, a country and a
brand in the sidebar.
"""
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st

# Options pre-selected in the sidebar when they exist in the loaded data.
DEFAULT_COUNTRY = "COUNTRY_6B71"
DEFAULT_BRAND = "BRAND_24CB"


def load_data(year: int) -> "pd.DataFrame | None":
    """Load data from a CSV file for the given year.

    Reads ``validation/{year}.csv``. If the file does not exist, an error is shown
    in the app and ``None`` is returned.
    """
    try:
        return pd.read_csv(f"validation/{year}.csv")
    except FileNotFoundError:
        st.error(f"No data found for year {year}. Please ensure the file exists.")
        return None


def filter_data(data: pd.DataFrame, country: str, brand: str) -> pd.DataFrame:
    """Filter the data for the selected country and brand.

    Returns the rows of *data* whose ``country`` and ``brand`` columns both match.
    """
    return data[(data['country'] == country) & (data['brand'] == brand)]


def plot_data(filtered_data: pd.DataFrame) -> None:
    """Plot target vs. date with a confidence interval.

    Expects the columns ``date``, ``target``, ``prediction``, ``prediction_10`` and
    ``prediction_90``. Shows a warning instead of a plot when *filtered_data* is empty.
    """
    if filtered_data.empty:
        st.warning("No data available for the selected criteria.")
        return

    st.write("Plotting target vs date with confidence intervals.")

    # Line plots and fill_between connect points in row order, so sort
    # chronologically; otherwise an unordered CSV produces zig-zag lines.
    plot_df = filtered_data.assign(
        date=pd.to_datetime(filtered_data['date'])
    ).sort_values('date')
    dates = plot_df['date']

    # Use an explicit figure (not pyplot's global one) so it can be closed below;
    # Streamlit reruns the script on every interaction and unclosed figures accumulate.
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        # Plot the target
        ax.plot(dates, plot_df['target'], label='Target', color='blue')
        # Plot the prediction with confidence interval
        ax.plot(dates, plot_df['prediction'], label='Prediction', color='orange')
        ax.fill_between(
            dates,
            plot_df['prediction_10'],
            plot_df['prediction_90'],
            color='orange',
            alpha=0.2,
            label='Confidence Interval (10th to 90th percentile)',
        )
        ax.set_xlabel('Date')
        ax.set_ylabel('Target')
        ax.set_title('Target vs Date with Confidence Interval')
        ax.legend()
        ax.grid(True)
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _with_preferred_first(options, preferred):
    """Return *options* as a list with *preferred* moved to the front, if present."""
    ordered = list(options)
    if preferred in ordered:
        ordered.remove(preferred)
        ordered.insert(0, preferred)
    return ordered


def main() -> None:
    """Build the sidebar selectors and render the plot for the chosen selection."""
    st.title("Data Visualization App")

    # Step 1: Select Year, default to 2021 (index 4 of 2017..2021)
    year = st.sidebar.selectbox("Select Year", range(2017, 2022), index=4)

    # Load data based on year selection; load_data already reported any error.
    data = load_data(year)
    if data is None:
        return

    # Step 2: Select Country based on available options for the year.
    # The default is pre-selected only if it actually occurs in this year's data.
    countries = _with_preferred_first(data['country'].unique(), DEFAULT_COUNTRY)
    country = st.sidebar.selectbox("Select Country", countries)

    # Step 3: Select Brand based on available options for the year and country.
    brands = _with_preferred_first(
        data[data['country'] == country]['brand'].unique(), DEFAULT_BRAND
    )
    brand = st.sidebar.selectbox("Select Brand", brands)

    # Filter data based on inputs, then plot it.
    plot_data(filter_data(data, country, brand))


if __name__ == "__main__":
    main()