"""Streamlit front end for short-term energy consumption forecasting.

Lets the user upload a JSON file of consumption data and previews it
(plots, raw JSON, statistics). It then sends the file to the forecasting
API via ``call_api`` and renders the returned forecast the same way.
"""

import json
import os

import streamlit as st
from dotenv import load_dotenv

from utils import (
    call_api,
    create_time_series_plot,
    display_statistics,
    load_and_process_data,
)

# Endpoint name passed to ``call_api`` for short-term consumption inference.
FORECAST_ENDPOINT = "inference_consumption_short_term"

# Series removed from the forecast view before plotting/statistics.
# NOTE(review): presumably a temperature series carried over from the input
# data rather than a forecast output — confirm against utils/the API.
EXCLUDED_FORECAST_UNITS = ("Celsius",)

# Session-state keys holding per-upload data.
_DATA_KEYS = ("current_file", "json_data", "api_response")

# Load variables from a local .env file (if present) into the environment so
# the API token does not have to live in source code.
load_dotenv()

if "api_token" not in st.session_state:
    # NOTE(review): the environment variable name "API_TOKEN" is assumed —
    # confirm against the deployment configuration.
    st.session_state.api_token = os.getenv("API_TOKEN", "")
    # First run of a new session: clear any other leftover state.
    for key in _DATA_KEYS:
        if key in st.session_state:
            del st.session_state[key]

# Initialize session state variables.
for key in _DATA_KEYS:
    if key not in st.session_state:
        st.session_state[key] = None

# Token input that the "enter your API token in the sidebar" error refers to.
# The widget key binds it to st.session_state.api_token, which was seeded above
# (before the widget is created, as Streamlit requires).
st.sidebar.text_input("API token", key="api_token", type="password")

st.title("Short Term Energy Consumption Forecasting")
st.markdown("""
This service provides short-term forecasting of energy consumption patterns. Upload your energy consumption data to generate predictions for the near future.

### Features
- Hourly consumption forecasting
- Interactive visualizations
- Statistical analysis of predictions
""")

# File upload and processing
uploaded_file = st.file_uploader("Upload JSON file", type=["json"])

if uploaded_file:
    try:
        # getvalue() returns the whole buffer regardless of the stream
        # position; read() would return b"" once the stream was consumed.
        file_contents = uploaded_file.getvalue()
        if file_contents != st.session_state.current_file:
            # A different file was uploaded: drop the forecast computed for
            # the previous file so it is not shown against the new input.
            st.session_state.api_response = None
        st.session_state.current_file = file_contents
        st.session_state.json_data = json.loads(file_contents)

        dfs = load_and_process_data(st.session_state.json_data)

        if dfs:
            st.header("Input Data")
            tabs = st.tabs(["Visualization", "Raw JSON", "Statistics"])
            with tabs[0]:
                for unit, df in dfs.items():
                    st.plotly_chart(
                        create_time_series_plot(df, unit),
                        use_container_width=True,
                    )
            with tabs[1]:
                st.json(st.session_state.json_data)
            with tabs[2]:
                display_statistics(dfs)

            if st.button("Generate Short Term Forecast"):
                if not st.session_state.api_token:
                    st.error("Please enter your API token in the sidebar first.")
                else:
                    with st.spinner("Generating forecast..."):
                        st.session_state.api_response = call_api(
                            st.session_state.current_file,
                            st.session_state.api_token,
                            FORECAST_ENDPOINT,
                        )
    except Exception as e:  # UI boundary: report any failure to the user.
        st.error(f"Error processing file: {e}")

# Display API results
if st.session_state.api_response:
    st.header("Forecast Results")

    # Parse once, up front, so every tab below sees the same result and a
    # parsing failure is reported instead of crashing the script.
    try:
        response_dfs = load_and_process_data(
            st.session_state.api_response,
            input_data=st.session_state.json_data,
        )
    except Exception as e:  # UI boundary: report any failure to the user.
        response_dfs = None
        st.error(f"Error processing forecast: {e}")

    if response_dfs:
        for excluded in EXCLUDED_FORECAST_UNITS:
            # pop() rather than del: a response without this series is fine.
            response_dfs.pop(excluded, None)

    tabs = st.tabs(["Visualization", "Raw JSON", "Statistics"])
    with tabs[0]:
        if response_dfs:
            for unit, df in response_dfs.items():
                st.plotly_chart(
                    create_time_series_plot(df, unit),
                    use_container_width=True,
                )
    with tabs[1]:
        st.json(st.session_state.api_response)
    with tabs[2]:
        if response_dfs:
            display_statistics(response_dfs)