Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import yfinance as yf | |
st.set_page_config( | |
page_title="US Stock Forecast", | |
page_icon="logo.png", | |
menu_items=None | |
) | |
st.write("# US Stocks Forecast") | |
# Importing forecasting algorithms from the algo directory | |
from algo.sarima import sarima_forecast | |
from algo.linear_regression import linear_regression_forecast | |
from algo.tbats import tbats_forecast | |
from algo.random_forest import random_forest_forecast | |
# Function to fetch stock data | |
def fetch_stock_data(ticker, start_date, end_date): | |
data = yf.download(ticker, start=start_date, end=end_date) | |
return data['Close'] | |
# Function to plot forecasts | |
def plot_forecasts(data, forecasts, title='Stock Price Forecast'): | |
plt.figure(figsize=(10, 6)) | |
plt.plot(data.index, data, label='Historical Prices', color='black', alpha=0.75) | |
for name, forecast in forecasts.items(): | |
plt.plot(forecast.index, forecast, label=name) | |
if len(forecasts) > 1: | |
combined_forecast = pd.concat(forecasts.values()).groupby(level=0).mean() | |
plt.plot(combined_forecast.index, combined_forecast, label='Combined Forecast', color='red', linestyle='--') | |
plt.title(title) | |
plt.xlabel('Date') | |
plt.ylabel('Price') | |
plt.legend() | |
st.pyplot(plt) | |
# Streamlit UI in Sidebar | |
st.sidebar.title("Input Parameters") | |
ticker = st.sidebar.text_input('Enter Ticker Symbol', 'AAPL') | |
start_date = st.sidebar.date_input('Select Start Date', value=pd.to_datetime('2020-01-01')) | |
end_date = st.sidebar.date_input('Select End Date', value=pd.to_datetime('2023-01-01')) | |
forecast_horizon = st.sidebar.number_input('Forecast Horizon (days)', min_value=1, value=180) | |
forecast_date = st.sidebar.date_input('Forecast Date', min_value=end_date, value=end_date + pd.Timedelta(days=180)) | |
# User selects which forecasting models to use in Sidebar | |
options = st.sidebar.multiselect('Select forecasting models to use', | |
['SARIMA', 'Linear Regression', 'TBATS', 'Random Forest'], | |
['SARIMA', 'Linear Regression']) | |
if st.sidebar.button('Analyze'): | |
data = fetch_stock_data(ticker, start_date, end_date) | |
forecasts = {} | |
if 'SARIMA' in options: | |
forecasts['SARIMA'] = sarima_forecast(data, forecast_horizon) | |
if 'Linear Regression' in options: | |
forecasts['Linear Regression'] = linear_regression_forecast(data, forecast_horizon) | |
if 'TBATS' in options: | |
forecasts['TBATS'] = tbats_forecast(data, forecast_horizon) | |
if 'Random Forest' in options: | |
forecasts['Random Forest'] = random_forest_forecast(data, forecast_horizon) | |
plot_forecasts(data, forecasts, f"Forecasted Stock Prices for {ticker}") | |
# Output the forecasted price for the selected date, if available | |
forecast_date_str = forecast_date.strftime('%Y-%m-%d') | |
for model_name, forecast in forecasts.items(): | |
if forecast_date_str in forecast.index: | |
st.write(f"Forecasted price by {model_name} on {forecast_date_str}: {forecast.loc[forecast_date_str]:.2f}") | |