Spaces:
Runtime error
Runtime error
File size: 3,136 Bytes
7e96883 d219cf7 7e96883 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yfinance as yf
# Browser-tab metadata for this app, handed to Streamlit as keyword arguments.
_page_settings = {
    "page_title": "US Stock Forecast",
    "page_icon": "logo.png",
    "menu_items": None,
}
st.set_page_config(**_page_settings)
# Main page heading (rendered from a Markdown string).
st.write("# US Stocks Forecast")
# Importing forecasting algorithms from the algo directory
from algo.sarima import sarima_forecast
from algo.linear_regression import linear_regression_forecast
from algo.tbats import tbats_forecast
from algo.random_forest import random_forest_forecast
def _extract_close(data):
    """Return the ``Close`` prices from a ``yf.download`` result as a Series.

    Depending on the yfinance version, the downloaded frame may have flat
    columns (``Close``, ``Open``, ...) or MultiIndex columns ``(field, ticker)``
    even for a single ticker. In the latter case ``data['Close']`` is a
    one-column DataFrame; that column is unwrapped into a Series named
    ``'Close'`` so callers always receive a Series for a single ticker.

    Raises:
        ValueError: if ``data`` has no ``Close`` column.
    """
    if 'Close' not in data.columns.get_level_values(0):
        raise ValueError("Downloaded data has no 'Close' column.")
    close = data['Close']
    if isinstance(close, pd.DataFrame) and close.shape[1] == 1:
        # Single ticker under MultiIndex columns: drop the ticker level.
        close = close.iloc[:, 0].rename('Close')
    return close


def fetch_stock_data(ticker, start_date, end_date):
    """Download price history for ``ticker`` via ``yf.download`` and return closing prices.

    Args:
        ticker: ticker symbol passed straight to ``yf.download``.
        start_date: first date of the requested range (``start=``).
        end_date: end of the requested range (``end=``).

    Returns:
        The ``Close`` column, indexed as returned by yfinance.

    Raises:
        ValueError: if nothing was downloaded (e.g. an unknown ticker or an
            empty date range) or the result has no ``Close`` column.
    """
    data = yf.download(ticker, start=start_date, end=end_date)
    if data is None or data.empty:
        raise ValueError(
            f"No price data returned for '{ticker}' between {start_date} and {end_date}."
        )
    return _extract_close(data)
def plot_forecasts(data, forecasts, title='Stock Price Forecast'):
    """Draw historical prices plus each model's forecast and show it in Streamlit.

    Args:
        data: historical prices; its index is used as the x-axis.
        forecasts: mapping of model name -> forecast series (index = x-axis).
        title: chart title.

    When more than one forecast is supplied, their per-date mean is also drawn
    as a dashed red "Combined Forecast" line.
    """
    # Work on an explicit Figure instead of pyplot's implicit global figure:
    # st.pyplot gets a real Figure object, and the figure is closed afterwards
    # so repeated Streamlit reruns do not accumulate open Matplotlib figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(data.index, data, label='Historical Prices', color='black', alpha=0.75)
        for name, forecast in forecasts.items():
            ax.plot(forecast.index, forecast, label=name)
        if len(forecasts) > 1:
            # Stack all forecasts and average per index label; a date covered by
            # only some models is averaged over those models alone.
            combined_forecast = pd.concat(list(forecasts.values())).groupby(level=0).mean()
            ax.plot(combined_forecast.index, combined_forecast,
                    label='Combined Forecast', color='red', linestyle='--')
        ax.set_title(title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Streamlit UI in Sidebar
st.sidebar.title("Input Parameters")
ticker = st.sidebar.text_input('Enter Ticker Symbol', 'AAPL')
start_date = st.sidebar.date_input('Select Start Date', value=pd.to_datetime('2020-01-01'))
end_date = st.sidebar.date_input('Select End Date', value=pd.to_datetime('2023-01-01'))
forecast_horizon = st.sidebar.number_input('Forecast Horizon (days)', min_value=1, value=180)
# Default the target date to the end of the chosen horizon; a fixed +180 days
# lies outside the forecast whenever the horizon is shorter than that.
forecast_date = st.sidebar.date_input(
    'Forecast Date',
    min_value=end_date,
    value=end_date + pd.Timedelta(days=int(forecast_horizon)),
)

# Forecasting models offered in the UI, in the order they are run and plotted.
FORECAST_MODELS = {
    'SARIMA': sarima_forecast,
    'Linear Regression': linear_regression_forecast,
    'TBATS': tbats_forecast,
    'Random Forest': random_forest_forecast,
}

# User selects which forecasting models to use in Sidebar
options = st.sidebar.multiselect('Select forecasting models to use',
                                 list(FORECAST_MODELS),
                                 ['SARIMA', 'Linear Regression'])


def _forecast_value_on(forecast, date_str):
    """Return the value of ``forecast`` labelled ``date_str`` ('YYYY-MM-DD'), or None.

    On a DatetimeIndex pandas treats the string as a partial date, so ``.loc``
    may yield several values (e.g. for an intraday index); the first one is
    used in that case.
    """
    if date_str not in forecast.index:
        return None
    values = np.ravel(np.asarray(forecast.loc[date_str], dtype=float))
    return float(values[0]) if values.size else None


def _run_analysis(ticker, start_date, end_date, horizon, forecast_date, selected):
    """Fetch prices, run the selected models, plot them, and report target-date values.

    Input problems and per-model failures are reported in the page instead of
    aborting the whole run.
    """
    ticker = ticker.strip()
    if not ticker:
        st.error('Please enter a ticker symbol.')
        return
    if start_date >= end_date:
        st.error('Start date must be before end date.')
        return
    try:
        data = fetch_stock_data(ticker, start_date, end_date)
    except ValueError as exc:
        st.error(str(exc))
        return
    if data.empty:
        st.error(f"No price data found for {ticker}.")
        return

    if not selected:
        st.info('No forecasting model selected; showing historical prices only.')
    forecasts = {}
    for name, model in FORECAST_MODELS.items():
        if name not in selected:
            continue
        try:
            forecasts[name] = model(data, horizon)
        except Exception as exc:  # UI boundary: report this model, keep the others
            st.warning(f"{name} forecast failed: {exc}")

    plot_forecasts(data, forecasts, f"Forecasted Stock Prices for {ticker}")

    # Output the forecasted price for the selected date, if available
    forecast_date_str = forecast_date.strftime('%Y-%m-%d')
    missing = []
    for model_name, forecast in forecasts.items():
        value = _forecast_value_on(forecast, forecast_date_str)
        if value is None:
            missing.append(model_name)
        else:
            st.write(f"Forecasted price by {model_name} on {forecast_date_str}: {value:.2f}")
    if missing:
        st.info(
            f"No forecast value for {forecast_date_str} from: {', '.join(missing)}. "
            "The date may lie outside the forecast horizon or not be one of the "
            "dates the model produced (e.g. a non-trading day)."
        )


if st.sidebar.button('Analyze'):
    _run_analysis(ticker, start_date, end_date, forecast_horizon, forecast_date, options)
|