jonathan-cristovao commited on
Commit
ef94dec
·
verified ·
1 Parent(s): b2ed964

Upload 10 files

Browse files
Files changed (10) hide show
  1. .gitattributes +35 -35
  2. .vscode/settings.json +9 -0
  3. README.md +13 -13
  4. app.py +14 -0
  5. model.py +52 -0
  6. model_page.py +49 -0
  7. plots.py +53 -0
  8. requirements.txt +9 -0
  9. stock_data_loader.py +19 -0
  10. view_page.py +50 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.vscode/settings.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python.linting.enabled": true,
3
+ "python.linting.pylintEnabled": true,
4
+ "files.exclude": {
5
+ "**/*.pyc": {"when": "$(basename).py"},
6
+ "**/__pycache__": true,
7
+ "**/*.pytest_cache": true,
8
+ }
9
+ }
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Real Time Stock Forecasting
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.36.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Stock Predict Lstm
3
+ emoji: 👁
4
+ colorFrom: blue
5
+ colorTo: gray
6
+ sdk: streamlit
7
+ sdk_version: 1.36.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from view_page import StockDashboard
3
+ from model_page import StockModelPage
4
+
5
+ def main():
6
+ st.set_page_config(layout='wide', page_title='Stock Analysis', page_icon=':dollar:')
7
+ page = st.sidebar.radio('Pages', ['View Page', 'Model Page'])
8
+ if page == 'View Page':
9
+ StockDashboard().run()
10
+ elif page == 'Model Page':
11
+ StockModelPage().run()
12
+
13
+ if __name__ == '__main__':
14
+ main()
model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from sklearn.preprocessing import MinMaxScaler
4
+ from keras.models import Sequential
5
+ from keras.layers import LSTM, Dense
6
+ import warnings
7
+ warnings.filterwarnings("ignore")
8
+
9
+ class Model:
10
+ def __init__(self, data):
11
+ self.data = data
12
+ self.scaler = MinMaxScaler(feature_range=(0, 1))
13
+ self.model = None
14
+
15
+ def prepare_data(self, look_back=1):
16
+ scaled_data = self.scaler.fit_transform(self.data['Close'].values.reshape(-1, 1))
17
+ def create_dataset(dataset):
18
+ X, Y = [], []
19
+ for i in range(len(dataset) - look_back):
20
+ a = dataset[i:(i + look_back), 0]
21
+ X.append(a)
22
+ Y.append(dataset[i + look_back, 0])
23
+ return np.array(X), np.array(Y)
24
+
25
+ X, Y = create_dataset(scaled_data)
26
+ X = np.reshape(X, (X.shape[0], 1, X.shape[1]))
27
+ return X, Y
28
+
29
+ def train_lstm(self, epochs=5, batch_size=1):
30
+ X, Y = self.prepare_data()
31
+ self.model = Sequential()
32
+ self.model.add(LSTM(50, input_shape=(1, 1)))
33
+ self.model.add(Dense(1))
34
+ self.model.compile(loss='mean_squared_error', optimizer='adam')
35
+ self.model.fit(X, Y, epochs=epochs, batch_size=batch_size, verbose=0)
36
+
37
+ def make_predictions(self):
38
+ X, _ = self.prepare_data()
39
+ predictions = self.model.predict(X)
40
+ predictions = self.scaler.inverse_transform(predictions)
41
+ return predictions
42
+
43
+ def forecast_future(self, days=5):
44
+ last_value = self.data['Close'].values[-1:].reshape(-1, 1)
45
+ last_scaled = self.scaler.transform(last_value)
46
+ future_predictions = []
47
+ for _ in range(days):
48
+ prediction = self.model.predict(last_scaled.reshape(1, 1, 1))[0]
49
+ future_predictions.append(prediction)
50
+ last_scaled = prediction
51
+ future_predictions = self.scaler.inverse_transform(future_predictions)
52
+ return future_predictions
model_page.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ from model import Model
4
+ from plots import Plots
5
+ from stock_data_loader import StockDataLoader
6
+
7
+ class StockModelPage:
8
+ def __init__(self):
9
+ self.tickers = ['NVDA', 'AAPL', 'GOOGL', 'MSFT', 'AMZN']
10
+ self.setup_sidebar()
11
+
12
+ def setup_sidebar(self):
13
+ self.ticker = st.sidebar.selectbox('Choose Stock Ticker', self.tickers)
14
+ self.start_date = st.sidebar.date_input('Start Date', value=pd.to_datetime('2010-01-01'))
15
+ self.end_date = st.sidebar.date_input('End Date', value=pd.to_datetime('today'))
16
+ self.load_button_clicked = st.sidebar.button('Load Data')
17
+
18
+ def load_data(self):
19
+ if self.load_button_clicked:
20
+ loader = StockDataLoader(self.ticker, self.start_date, self.end_date)
21
+ st.session_state['stock_data'] = loader.get_stock_data()
22
+ st.write("--------------------------------------------")
23
+ st.write(f"Data for {self.ticker} from {self.start_date} to {self.end_date} loaded successfully!")
24
+
25
+ def handle_model_training(self):
26
+ if 'stock_data' in st.session_state:
27
+ stock_data = st.session_state['stock_data']
28
+ if st.button('Train Model'):
29
+ st.write("Training Model...")
30
+ model = Model(stock_data)
31
+ model.train_lstm()
32
+ predictions = model.make_predictions()
33
+ future_predictions = model.forecast_future(days=5)
34
+ self.plot_predictions(stock_data, predictions, future_predictions)
35
+ else:
36
+ st.write("Click the button above to train the model.")
37
+ else:
38
+ st.write("--------------------------------------------")
39
+ st.write("Please load data before training the model.")
40
+
41
+ def plot_predictions(self, stock_data, predictions, future_predictions):
42
+ plot_instance = Plots(stock_data)
43
+ plot_instance.plot_predictions(predictions, future_predictions)
44
+
45
+ def run(self):
46
+ st.write("--------------------------------------------")
47
+ st.write(f'<div style="font-size:50px">🤖 Real-Time Stock Prediction', unsafe_allow_html=True)
48
+ self.load_data()
49
+ self.handle_model_training()
plots.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
+
6
+
7
+ class StockChart:
8
+ def __init__(self, data):
9
+ self.data = data
10
+ self.fig = make_subplots(rows=2, cols=1, vertical_spacing=0.01, shared_xaxes=True)
11
+
12
+ def add_price_chart(self):
13
+ self.fig.add_trace(go.Scatter(x=self.data.index, y=self.data['Open'], name='Open Price', marker_color='#1F77B4'), row=1, col=1)
14
+ self.fig.add_trace(go.Scatter(x=self.data.index, y=self.data['High'], name='High Price', marker_color='#9467BD'), row=1, col=1)
15
+ self.fig.add_trace(go.Scatter(x=self.data.index, y=self.data['Low'], name='Low Price', marker_color='#D62728'), row=1, col=1)
16
+ self.fig.add_trace(go.Scatter(x=self.data.index, y=self.data['Close'], name='Close Price', marker_color='#76B900'), row=1, col=1)
17
+
18
+
19
+ def add_oversold_overbought_lines(self):
20
+ self.fig.add_hline(y=30, line_dash='dash', line_color='limegreen', line_width=1, row=1, col=1)
21
+ self.fig.add_hline(y=70, line_dash='dash', line_color='red', line_width=1, row=1, col=1)
22
+ self.fig.update_yaxes(title_text='RSI Score', row=1, col=1)
23
+
24
+ def add_volume_chart(self):
25
+ colors = ['#9C1F0B' if row['Open'] - row['Close'] >= 0 else '#2B8308' for index, row in self.data.iterrows()]
26
+ self.fig.add_trace(go.Bar(x=self.data.index, y=self.data['Volume'], showlegend=False, marker_color=colors), row=2, col=1)
27
+
28
+ def render_chart(self):
29
+ self.fig.update_layout(title='Historical Price and Volume', height=500, margin=dict(l=0, r=10, b=10, t=25))
30
+ st.plotly_chart(self.fig, use_container_width=True)
31
+
32
+ class Plots:
33
+ def __init__(self, data):
34
+ self.data = data
35
+
36
+ def plot_predictions(self, predictions, future_predictions):
37
+
38
+ predicted_dates = self.data.index[-len(predictions):]
39
+ future_dates = pd.date_range(start=self.data.index[-1] + pd.Timedelta(days=1), periods=len(future_predictions), freq='B')
40
+ predictions = [float(val) for val in predictions if pd.notna(val)]
41
+ future_predictions = [float(val) for val in future_predictions if pd.notna(val)]
42
+
43
+ fig = make_subplots(rows=1, cols=1)
44
+ fig.add_trace(go.Scatter(x=self.data.index, y=self.data['Close'], mode='lines', name='Actual Stock Prices', marker_color='blue'))
45
+ fig.add_trace(go.Scatter(x=predicted_dates, y=predictions, mode='lines', name='LSTM Predicted Prices', marker_color='red', line=dict(dash='dash')))
46
+ fig.add_trace(go.Scatter(x=future_dates, y=future_predictions, mode='lines', name='Future Predictions', marker_color='green', line=dict(dash='dot')))
47
+
48
+ fig.update_layout(title='Comparison of Actual, Predicted, and Future Stock Prices', xaxis_title='Date', yaxis_title='Price', legend_title='Legend', height=500)
49
+ st.plotly_chart(fig, use_container_width=True)
50
+
51
+
52
+
53
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ seaborn
4
+ matplotlib
5
+ keras
6
+ tensorflow
7
+ scikit-learn
8
+ yfinance
9
+ plotly
stock_data_loader.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import yfinance as yf
3
+
4
+ import warnings
5
+ warnings.filterwarnings("ignore")
6
+
7
+ class StockDataLoader:
8
+ def __init__(self, ticker, start_date, end_date):
9
+ self.ticker = ticker
10
+ self.start_date = start_date
11
+ self.end_date = end_date
12
+
13
+ def get_stock_data(self):
14
+ stock = yf.Ticker(self.ticker)
15
+ stock_data = stock.history(start=self.start_date, end=self.end_date)
16
+ stock_data.reset_index(inplace=True)
17
+ stock_data['Date'] = pd.to_datetime(stock_data['Date'])
18
+ stock_data.set_index('Date', inplace=True)
19
+ return stock_data
view_page.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stock_data_loader import StockDataLoader
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import yfinance as yf
5
+ from datetime import datetime
6
+ from plots import Plots, StockChart
7
+
8
+ class StockDashboard:
9
+ def __init__(self):
10
+ self.tickers = ['NVDA', 'AAPL', 'GOOGL', 'MSFT', 'AMZN']
11
+ self.period_map = {'all': 'max','1m': '1mo', '6m': '6mo', '1y': '1y'}
12
+
13
+ def render_sidebar(self):
14
+ st.sidebar.header("Choose your filter:")
15
+ self.ticker = st.sidebar.selectbox('Choose Ticker', options=self.tickers, help='Select a ticker')
16
+ self.selected_range = st.sidebar.selectbox('Select Period', options=list(self.period_map.keys()))
17
+
18
+ def load_data(self):
19
+ self.yf_data = yf.Ticker(self.ticker)
20
+ self.df_history = self.yf_data.history(period=self.period_map[self.selected_range])
21
+ self.current_price = self.yf_data.info.get('currentPrice', 'N/A')
22
+ self.previous_close = self.yf_data.info.get('previousClose', 'N/A')
23
+
24
+ def display_header(self):
25
+ company_name = self.yf_data.info['shortName']
26
+ symbol = self.yf_data.info['symbol']
27
+ st.subheader(f'{company_name} ({symbol}) 💰')
28
+ st.divider()
29
+ if self.current_price != 'N/A' and self.previous_close != 'N/A':
30
+ price_change = self.current_price - self.previous_close
31
+ price_change_ratio = (abs(price_change) / self.previous_close * 100)
32
+ price_change_direction = "+" if price_change > 0 else "-"
33
+ st.metric(label='Current Price', value=f"{self.current_price:.2f}",
34
+ delta=f"{price_change:.2f} ({price_change_direction}{price_change_ratio:.2f}%)")
35
+
36
+ def plot_data(self):
37
+ chart = StockChart(self.df_history)
38
+ chart.add_price_chart()
39
+ chart.add_oversold_overbought_lines()
40
+ chart.add_volume_chart()
41
+ chart.render_chart()
42
+
43
+ def run(self):
44
+ st.write("--------------------------------------------")
45
+ self.render_sidebar()
46
+ self.load_data()
47
+ self.display_header()
48
+ self.plot_data()
49
+
50
+