netflypsb commited on
Commit
7e96883
·
verified ·
1 Parent(s): cee3786

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import yfinance as yf
6
+
7
+ # Importing forecasting algorithms from the algo directory
8
+ from algo.sarima import sarima_forecast
9
+ from algo.linear_regression import linear_regression_forecast
10
+ from algo.tbats import tbats_forecast
11
+ from algo.random_forest import random_forest_forecast
12
+
13
+ # Function to fetch stock data
14
+ def fetch_stock_data(ticker, start_date, end_date):
15
+ data = yf.download(ticker, start=start_date, end=end_date)
16
+ return data['Close']
17
+
18
+ # Function to plot forecasts
19
+ def plot_forecasts(data, forecasts, title='Stock Price Forecast'):
20
+ plt.figure(figsize=(10, 6))
21
+ plt.plot(data.index, data, label='Historical Prices', color='black', alpha=0.75)
22
+
23
+ for name, forecast in forecasts.items():
24
+ plt.plot(forecast.index, forecast, label=name)
25
+
26
+ if len(forecasts) > 1:
27
+ combined_forecast = pd.concat(forecasts.values()).groupby(level=0).mean()
28
+ plt.plot(combined_forecast.index, combined_forecast, label='Combined Forecast', color='red', linestyle='--')
29
+
30
+ plt.title(title)
31
+ plt.xlabel('Date')
32
+ plt.ylabel('Price')
33
+ plt.legend()
34
+ st.pyplot(plt)
35
+
36
+ # Streamlit UI in Sidebar
37
+ st.sidebar.title("Input Parameters")
38
+ ticker = st.sidebar.text_input('Enter Ticker Symbol', 'AAPL')
39
+ start_date = st.sidebar.date_input('Select Start Date', value=pd.to_datetime('2020-01-01'))
40
+ end_date = st.sidebar.date_input('Select End Date', value=pd.to_datetime('2023-01-01'))
41
+ forecast_horizon = st.sidebar.number_input('Forecast Horizon (days)', min_value=1, value=180)
42
+ forecast_date = st.sidebar.date_input('Forecast Date', min_value=end_date, value=end_date + pd.Timedelta(days=180))
43
+
44
+ # User selects which forecasting models to use in Sidebar
45
+ options = st.sidebar.multiselect('Select forecasting models to use',
46
+ ['SARIMA', 'Linear Regression', 'TBATS', 'Random Forest'],
47
+ ['SARIMA', 'Linear Regression'])
48
+
49
+ if st.sidebar.button('Analyze'):
50
+ data = fetch_stock_data(ticker, start_date, end_date)
51
+ forecasts = {}
52
+
53
+ if 'SARIMA' in options:
54
+ forecasts['SARIMA'] = sarima_forecast(data, forecast_horizon)
55
+ if 'Linear Regression' in options:
56
+ forecasts['Linear Regression'] = linear_regression_forecast(data, forecast_horizon)
57
+ if 'TBATS' in options:
58
+ forecasts['TBATS'] = tbats_forecast(data, forecast_horizon)
59
+ if 'Random Forest' in options:
60
+ forecasts['Random Forest'] = random_forest_forecast(data, forecast_horizon)
61
+
62
+ plot_forecasts(data, forecasts, f"Forecasted Stock Prices for {ticker}")
63
+
64
+ # Output the forecasted price for the selected date, if available
65
+ forecast_date_str = forecast_date.strftime('%Y-%m-%d')
66
+ for model_name, forecast in forecasts.items():
67
+ if forecast_date_str in forecast.index:
68
+ st.write(f"Forecasted price by {model_name} on {forecast_date_str}: {forecast.loc[forecast_date_str]:.2f}")