Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import yfinance as yf
|
6 |
+
|
7 |
+
# Importing forecasting algorithms from the algo directory
|
8 |
+
from algo.sarima import sarima_forecast
|
9 |
+
from algo.linear_regression import linear_regression_forecast
|
10 |
+
from algo.tbats import tbats_forecast
|
11 |
+
from algo.random_forest import random_forest_forecast
|
12 |
+
|
13 |
+
# Function to fetch stock data
|
14 |
+
def fetch_stock_data(ticker, start_date, end_date):
|
15 |
+
data = yf.download(ticker, start=start_date, end=end_date)
|
16 |
+
return data['Close']
|
17 |
+
|
18 |
+
# Function to plot forecasts
|
19 |
+
def plot_forecasts(data, forecasts, title='Stock Price Forecast'):
|
20 |
+
plt.figure(figsize=(10, 6))
|
21 |
+
plt.plot(data.index, data, label='Historical Prices', color='black', alpha=0.75)
|
22 |
+
|
23 |
+
for name, forecast in forecasts.items():
|
24 |
+
plt.plot(forecast.index, forecast, label=name)
|
25 |
+
|
26 |
+
if len(forecasts) > 1:
|
27 |
+
combined_forecast = pd.concat(forecasts.values()).groupby(level=0).mean()
|
28 |
+
plt.plot(combined_forecast.index, combined_forecast, label='Combined Forecast', color='red', linestyle='--')
|
29 |
+
|
30 |
+
plt.title(title)
|
31 |
+
plt.xlabel('Date')
|
32 |
+
plt.ylabel('Price')
|
33 |
+
plt.legend()
|
34 |
+
st.pyplot(plt)
|
35 |
+
|
36 |
+
# Streamlit UI in Sidebar
|
37 |
+
st.sidebar.title("Input Parameters")
|
38 |
+
ticker = st.sidebar.text_input('Enter Ticker Symbol', 'AAPL')
|
39 |
+
start_date = st.sidebar.date_input('Select Start Date', value=pd.to_datetime('2020-01-01'))
|
40 |
+
end_date = st.sidebar.date_input('Select End Date', value=pd.to_datetime('2023-01-01'))
|
41 |
+
forecast_horizon = st.sidebar.number_input('Forecast Horizon (days)', min_value=1, value=180)
|
42 |
+
forecast_date = st.sidebar.date_input('Forecast Date', min_value=end_date, value=end_date + pd.Timedelta(days=180))
|
43 |
+
|
44 |
+
# User selects which forecasting models to use in Sidebar
|
45 |
+
options = st.sidebar.multiselect('Select forecasting models to use',
|
46 |
+
['SARIMA', 'Linear Regression', 'TBATS', 'Random Forest'],
|
47 |
+
['SARIMA', 'Linear Regression'])
|
48 |
+
|
49 |
+
if st.sidebar.button('Analyze'):
|
50 |
+
data = fetch_stock_data(ticker, start_date, end_date)
|
51 |
+
forecasts = {}
|
52 |
+
|
53 |
+
if 'SARIMA' in options:
|
54 |
+
forecasts['SARIMA'] = sarima_forecast(data, forecast_horizon)
|
55 |
+
if 'Linear Regression' in options:
|
56 |
+
forecasts['Linear Regression'] = linear_regression_forecast(data, forecast_horizon)
|
57 |
+
if 'TBATS' in options:
|
58 |
+
forecasts['TBATS'] = tbats_forecast(data, forecast_horizon)
|
59 |
+
if 'Random Forest' in options:
|
60 |
+
forecasts['Random Forest'] = random_forest_forecast(data, forecast_horizon)
|
61 |
+
|
62 |
+
plot_forecasts(data, forecasts, f"Forecasted Stock Prices for {ticker}")
|
63 |
+
|
64 |
+
# Output the forecasted price for the selected date, if available
|
65 |
+
forecast_date_str = forecast_date.strftime('%Y-%m-%d')
|
66 |
+
for model_name, forecast in forecasts.items():
|
67 |
+
if forecast_date_str in forecast.index:
|
68 |
+
st.write(f"Forecasted price by {model_name} on {forecast_date_str}: {forecast.loc[forecast_date_str]:.2f}")
|