Johan713 committed
Commit a1bc0de · verified · 1 parent: b18bfd5

Create app.py

Files changed (1)
  1. app.py +587 -0
app.py ADDED
@@ -0,0 +1,587 @@
import streamlit as st
import pandas as pd
import numpy as np
import yfinance as yf
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from datetime import datetime, timedelta
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet
from sklearn.ensemble import RandomForestRegressor
import xgboost as xgb
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from keras.models import Sequential
from keras.optimizers import Adam
from keras.layers import Dense, LSTM
from sklearn.model_selection import train_test_split, TimeSeriesSplit, GridSearchCV  # GridSearchCV is used in train_model()
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
import requests
from bs4 import BeautifulSoup
import base64
import warnings
from ta.trend import SMAIndicator, EMAIndicator
from ta.momentum import RSIIndicator
from ta.volatility import BollingerBands
from pmdarima import auto_arima
warnings.filterwarnings('ignore')

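# Note on dependencies (inferred from the imports above, not pinned by this commit):
# streamlit, pandas, numpy, yfinance, plotly, statsmodels, prophet, scikit-learn,
# xgboost, keras (with a TensorFlow backend), requests, beautifulsoup4, ta, and
# pmdarima need to be installed; lxml is also required because fetch_news() parses
# RSS via BeautifulSoup(..., features='xml'). The app is launched locally with
# `streamlit run app.py`.
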
# List of companies (display name, ticker symbol)
# Note: Facebook now trades as META; the old "FB" symbol no longer resolves on Yahoo Finance.
COMPANIES = [
    ("Apple", "AAPL"), ("Microsoft", "MSFT"), ("Amazon", "AMZN"), ("Google", "GOOGL"),
    ("Facebook", "META"), ("Tesla", "TSLA"), ("NVIDIA", "NVDA"), ("JPMorgan Chase", "JPM"),
    ("Johnson & Johnson", "JNJ"), ("Visa", "V"), ("Procter & Gamble", "PG"), ("UnitedHealth", "UNH"),
    ("Home Depot", "HD"), ("Mastercard", "MA"), ("Bank of America", "BAC"), ("Disney", "DIS"),
    ("Netflix", "NFLX"), ("Coca-Cola", "KO"), ("Pepsi", "PEP"), ("Adobe", "ADBE")
]

class StockPredictor:
    def __init__(self, data, model_type):
        self.data = data
        self.model_type = model_type
        self.model = None
        self.scaler = None
        self.lstm_scaler = None

    def preprocess_data(self):
        self.data['Date'] = pd.to_datetime(self.data.index)
        self.data = self.data.reset_index(drop=True)

        # Enhanced Feature Engineering
        self.data['DayOfWeek'] = self.data['Date'].dt.dayofweek
        self.data['Month'] = self.data['Date'].dt.month
        self.data['Year'] = self.data['Date'].dt.year
        self.data['IsMonthEnd'] = self.data['Date'].dt.is_month_end.astype(int)

        # Technical Indicators
        self.data['SMA_20'] = SMAIndicator(close=self.data['Close'], window=20).sma_indicator()
        self.data['EMA_20'] = EMAIndicator(close=self.data['Close'], window=20).ema_indicator()
        self.data['RSI'] = RSIIndicator(close=self.data['Close']).rsi()
        bb = BollingerBands(close=self.data['Close'], window=20, window_dev=2)
        self.data['BB_High'] = bb.bollinger_hband()
        self.data['BB_Low'] = bb.bollinger_lband()

        # Log returns
        self.data['LogReturn'] = np.log(self.data['Close'] / self.data['Close'].shift(1))

        # Handle NaN values
        self.data.dropna(inplace=True)

        # Define features for the model
        self.features = ['Open', 'High', 'Low', 'Close', 'Volume', 'SMA_20', 'EMA_20', 'RSI', 'BB_High', 'BB_Low', 'LogReturn', 'DayOfWeek', 'Month', 'Year', 'IsMonthEnd']

        # Apply scaling for XGBoost and RandomForest
        if self.model_type in ['XGBoost', 'RandomForest']:
            self.scaler = StandardScaler()
            self.data[self.features] = self.scaler.fit_transform(self.data[self.features])

        # Additional preprocessing for LSTM
        if self.model_type == 'LSTM':
            self.lstm_scaler = MinMaxScaler(feature_range=(0, 1))
            self.data['Scaled_Close'] = self.lstm_scaler.fit_transform(self.data[['Close']])

    def create_lstm_dataset(self, look_back=60):
        scaled_data = self.data['Scaled_Close'].values
        x, y = [], []
        for i in range(look_back, len(scaled_data)):
            x.append(scaled_data[i-look_back:i])
            y.append(scaled_data[i])
        return np.array(x), np.array(y)

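    # create_lstm_dataset() builds sliding windows: each sample is `look_back`
    # consecutive scaled closes and the target is the next scaled close. The
    # arrays come back 2-D (samples, look_back); train_model() adds the trailing
    # feature axis that Keras LSTM layers expect, giving (samples, look_back, 1).
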
    def train_model(self):
        try:
            if self.model_type == 'LSTM':
                x, y = self.create_lstm_dataset()
                # Keras LSTM layers require 3-D input: (samples, timesteps, features)
                x = x.reshape((x.shape[0], x.shape[1], 1))
                x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle=False)

                model = Sequential([
                    LSTM(50, return_sequences=True, input_shape=(x_train.shape[1], 1)),
                    LSTM(50, return_sequences=False),
                    Dense(25),
                    Dense(1)
                ])

                model.compile(optimizer=Adam(learning_rate=0.001), loss='mean_squared_error')
                model.fit(x_train, y_train, epochs=50, batch_size=32, validation_data=(x_test, y_test), verbose=0)

                self.model = model

            elif self.model_type == 'SARIMA':
                train_data = self.data['Close']
                # Use auto_arima (imported at the top) to find optimal parameters
                auto_model = auto_arima(train_data, start_p=1, start_q=1, max_p=3, max_q=3, m=12,
                                        start_P=0, seasonal=True, d=1, D=1, trace=True,
                                        error_action='ignore', suppress_warnings=True, stepwise=True)

                self.model = SARIMAX(train_data, order=auto_model.order, seasonal_order=auto_model.seasonal_order)
                self.model = self.model.fit(disp=False)

            elif self.model_type == 'Prophet':
                df = self.data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
                self.model = Prophet(
                    changepoint_prior_scale=0.05,
                    seasonality_prior_scale=10,
                    holidays_prior_scale=10,
                    daily_seasonality=True,
                    weekly_seasonality=True,
                    yearly_seasonality=True
                )
                for feature in ['SMA_20', 'EMA_20', 'RSI', 'BB_High', 'BB_Low']:
                    self.model.add_regressor(feature)
                    df[feature] = self.data[feature]
                self.model.fit(df)

            elif self.model_type == 'XGBoost':
                X = self.data[self.features]
                y = self.data['Close']
                X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

                param_grid = {
                    'max_depth': [3, 5],
                    'learning_rate': [0.01, 0.1],
                    'n_estimators': [100, 200]
                }
                model = xgb.XGBRegressor(objective='reg:squarederror')
                grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=3, n_jobs=-1, verbose=0)
                grid_search.fit(X_train, y_train)

                self.model = grid_search.best_estimator_

            elif self.model_type == 'RandomForest':
                X = self.data[self.features]
                y = self.data['Close']
                X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

                param_grid = {
                    'n_estimators': [100, 200],
                    'max_depth': [10, 20]
                }
                model = RandomForestRegressor(random_state=42)
                grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=3, n_jobs=-1, verbose=0)
                grid_search.fit(X_train, y_train)

                self.model = grid_search.best_estimator_

            return True

        except Exception as e:
            print(f"Error training {self.model_type} model: {str(e)}")
            return False

    def predict(self, days=30):
        try:
            if self.model_type == 'LSTM':
                last_sequence = self.data['Scaled_Close'].values[-60:].reshape(1, 60, 1)
                predictions = []
                for _ in range(days):
                    pred = self.model.predict(last_sequence)
                    predictions.append(pred[0, 0])
                    last_sequence = np.roll(last_sequence, -1, axis=1)
                    last_sequence[0, -1, 0] = pred[0, 0]
                return self.lstm_scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()

            elif self.model_type == 'SARIMA':
                forecast = self.model.get_forecast(steps=days)
                return forecast.predicted_mean

            elif self.model_type == 'Prophet':
                future = self.model.make_future_dataframe(periods=days)
                for feature in ['SMA_20', 'EMA_20', 'RSI', 'BB_High', 'BB_Low']:
                    future[feature] = self.data[feature].iloc[-1]  # Use last known value for each regressor
                forecast = self.model.predict(future)
                # Return the forecast frame for the requested horizon; the plotting helpers
                # and the results table expect the ds/yhat/yhat_lower/yhat_upper columns.
                return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(days).reset_index(drop=True)

            elif self.model_type in ['XGBoost', 'RandomForest']:
                last_data = self.data[self.features].iloc[-1:].values
                predictions = []
                for _ in range(days):
                    pred = self.model.predict(last_data)
                    predictions.append(pred[0])
                    # Naive recursive scheme: feed the prediction back in as the 'Close'
                    # feature for the next step; the remaining features are held fixed
                    last_data[0, self.features.index('Close')] = pred[0]
                return predictions

        except Exception as e:
            print(f"Error predicting with {self.model_type} model: {str(e)}")
            return None

    def evaluate_model(self, test_data):
        predictions = self.predict(len(test_data))
        # Prophet predictions come back as a DataFrame; the other models return arrays
        y_pred = predictions['yhat'].values if isinstance(predictions, pd.DataFrame) else np.asarray(predictions)
        mse = mean_squared_error(test_data['Close'], y_pred)
        mape = mean_absolute_percentage_error(test_data['Close'], y_pred)
        rmse = np.sqrt(mse)
        return mse, mape, rmse

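# Minimal usage sketch of StockPredictor outside Streamlit (illustrative only;
# the ticker and start date are arbitrary):
#     df = yf.download("AAPL", start="2015-01-01")
#     predictor = StockPredictor(df, model_type='Prophet')
#     predictor.preprocess_data()
#     if predictor.train_model():
#         forecast = predictor.predict(days=30)  # DataFrame with ds/yhat/yhat_lower/yhat_upper
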
def fetch_stock_data(ticker):
    try:
        end_date = datetime.now()
        start_date = datetime(2000, 1, 1)
        data = yf.download(ticker, start=start_date, end=end_date)
        return data
    except Exception as e:
        st.error(f"Error fetching data for {ticker}: {str(e)}")
        return None

def create_test_plot(train_data, test_data, predicted_data, company_name):
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=train_data.index,
        y=train_data['Close'],
        mode='lines',
        name='Training Data',
        line=dict(color='blue')
    ))

    fig.add_trace(go.Scatter(
        x=test_data.index,
        y=test_data['Close'],
        mode='lines',
        name='Actual (Test) Data',
        line=dict(color='green')
    ))

    if predicted_data is not None:
        fig.add_trace(go.Scatter(
            x=test_data.index,  # Align predicted data with test data
            y=predicted_data['yhat'].iloc[-len(test_data):],
            mode='lines',
            name='Predicted Data',
            line=dict(color='red', dash='dash')
        ))

    fig.update_layout(
        title=f'{company_name} Stock Price Prediction (Test Model)',
        xaxis_title='Date',
        yaxis_title='Close Price',
        template='plotly_dark',
        hovermode='x unified',
        xaxis_rangeslider_visible=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )
    return fig

def create_prediction_plot(data, predicted_data, company_name):
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=data.index,
        y=data['Close'],
        mode='lines',
        name='Historical Data',
        line=dict(color='cyan')
    ))

    if predicted_data is not None:
        future_dates = pd.date_range(start=data.index[-1] + pd.Timedelta(days=1), periods=len(predicted_data))
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=predicted_data['yhat'],
            mode='lines',
            name='Predicted Data',
            line=dict(color='yellow')
        ))

    fig.update_layout(
        title=f'{company_name} Stock Price Prediction',
        xaxis_title='Date',
        yaxis_title='Close Price',
        template='plotly_dark',
        hovermode='x unified',
        xaxis_rangeslider_visible=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )
    return fig

def create_candlestick_plot(data, company_name):
    fig = go.Figure(data=[go.Candlestick(x=data.index,
                                         open=data['Open'],
                                         high=data['High'],
                                         low=data['Low'],
                                         close=data['Close'])])
    fig.update_layout(
        title=f'{company_name} Stock Price Candlestick Chart',
        xaxis_title='Date',
        yaxis_title='Price',
        template='plotly_dark',
        xaxis_rangeslider_visible=True
    )
    return fig

def fetch_news(company_name):
    try:
        url = f"https://news.google.com/rss/search?q={company_name}+stock&hl=en-US&gl=US&ceid=US:en"
        response = requests.get(url)
        soup = BeautifulSoup(response.content, features='xml')
        news_items = soup.find_all('item')

        news = []
        for item in news_items[:5]:
            news.append({
                'title': item.title.text,
                'link': item.link.text,
                'pubDate': item.pubDate.text
            })

        return news
    except Exception as e:
        st.error(f"Error fetching news: {str(e)}")
        return []

def get_table_download_link(df):
    # Keep the index so the Date column is included in the exported CSV
    csv = df.to_csv()
    b64 = base64.b64encode(csv.encode()).decode()
    href = f'<a href="data:file/csv;base64,{b64}" download="stock_data.csv">Download CSV File</a>'
    return href

def main():
    st.set_page_config(page_title="Stock Price Predictor", layout="wide")
    st.title("Advanced Stock Price Predictor using Prophet")

    st.sidebar.title("Options")
    # "Explore Data" is listed here so the explore_data() view defined below is reachable
    app_mode = st.sidebar.selectbox("Choose the app mode", ["Test Model", "Predict Stock Prices", "Explore Data"])

    if app_mode == "Test Model":
        test_model()
    elif app_mode == "Predict Stock Prices":
        predict_stock_prices()
    else:
        explore_data()

def test_model():
    st.header("Test Prophet Model")

    col1, col2 = st.columns(2)

    with col1:
        company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
        test_split = st.slider("Test Data Split", 0.1, 0.5, 0.2, 0.05)

    if st.button("Train and Test Model"):
        with st.spinner("Fetching data and training model..."):
            company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)

            data = fetch_stock_data(ticker)

            if data is not None:
                st.subheader("Stock Data Information")
                st.write(data.info())
                st.write(data.describe())
                st.dataframe(data.head())

                st.markdown(get_table_download_link(data), unsafe_allow_html=True)

                split_index = int(len(data) * (1 - test_split))
                train_data = data.iloc[:split_index]
                test_data = data.iloc[split_index:]

                # The app is Prophet-based, so the model type is passed explicitly
                predictor = StockPredictor(train_data, model_type='Prophet')
                predictor.preprocess_data()
                if predictor.train_model():
                    test_pred = predictor.predict(days=len(test_data))

                    if test_pred is not None:
                        mse, mape, rmse = predictor.evaluate_model(test_data)
                        accuracy = 100 - mape * 100

                        st.subheader("Model Performance")
                        st.metric("Prediction Accuracy", f"{accuracy:.2f}%")
                        st.metric("Mean Squared Error", f"{mse:.4f}")
                        st.metric("Root Mean Squared Error", f"{rmse:.4f}")

                        # Plot against the raw train_data slice so the x-axis keeps its dates
                        plot = create_test_plot(train_data, test_data, test_pred, company_name)
                        st.plotly_chart(plot, use_container_width=True)
                else:
                    st.error("Failed to train the Prophet model. Please try a different dataset.")

def predict_stock_prices():
    st.header("Predict Stock Prices")

    col1, col2 = st.columns(2)

    with col1:
        company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
        days_to_predict = st.slider("Days to Predict", 1, 365, 30)

    if st.button("Predict Stock Prices"):
        with st.spinner("Fetching data and making predictions..."):
            company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)

            data = fetch_stock_data(ticker)

            if data is not None:
                st.subheader("Stock Data Information")
                st.write(data.info())
                st.write(data.describe())
                st.dataframe(data.head())

                st.markdown(get_table_download_link(data), unsafe_allow_html=True)

                predictor = StockPredictor(data, model_type='Prophet')
                predictor.preprocess_data()
                if predictor.train_model():
                    predictions = predictor.predict(days=days_to_predict)

                    if predictions is not None:
                        plot = create_prediction_plot(data, predictions, company_name)
                        st.plotly_chart(plot, use_container_width=True)

                        candlestick_plot = create_candlestick_plot(data, company_name)
                        st.plotly_chart(candlestick_plot, use_container_width=True)

                        st.subheader("Predicted Prices")
                        pred_df = predictions[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(days_to_predict)
                        pred_df.columns = ['Date', 'Predicted Price', 'Lower Bound', 'Upper Bound']
                        st.dataframe(pred_df)

                        news = fetch_news(company_name)
                        st.subheader("Latest News")
                        for item in news:
                            st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")
                else:
                    st.error("Failed to train the Prophet model. Please try a different dataset.")

def explore_data():
    st.header("Explore Stock Data")

    col1, col2 = st.columns(2)

    with col1:
        company = st.selectbox("Select Company", [company for company, _ in COMPANIES])

    with col2:
        period = st.selectbox("Select Time Period", ["1mo", "3mo", "6mo", "1y", "2y", "5y", "max"])

    company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)

    if st.button("Explore Data"):
        with st.spinner("Fetching and analyzing data..."):
            data = yf.download(ticker, period=period)

            if data is not None and not data.empty:
                st.subheader(f"{company_name} Stock Data")

                # Create tabs for different visualizations
                tab1, tab2, tab3, tab4, tab5 = st.tabs(["Price History", "OHLC", "Technical Indicators", "Volume & Turnover", "Statistics"])

                with tab1:
                    # Stock Price History
                    fig = go.Figure()
                    fig.add_trace(go.Scatter(x=data.index, y=data['Open'], mode='lines', name='Open'))
                    fig.add_trace(go.Scatter(x=data.index, y=data['High'], mode='lines', name='High'))
                    fig.add_trace(go.Scatter(x=data.index, y=data['Low'], mode='lines', name='Low'))
                    fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='Close'))

                    # Add rolling mean and standard deviation
                    data['Rolling_Mean'] = data['Close'].rolling(window=20).mean()
                    data['Rolling_Std'] = data['Close'].rolling(window=20).std()
                    fig.add_trace(go.Scatter(x=data.index, y=data['Rolling_Mean'], mode='lines', name='20-day Rolling Mean', line=dict(dash='dash')))
                    fig.add_trace(go.Scatter(x=data.index, y=data['Rolling_Std'], mode='lines', name='20-day Rolling Std', line=dict(dash='dot')))

                    fig.update_layout(title=f"{company_name} Stock Price History",
                                      xaxis_title="Date",
                                      yaxis_title="Price",
                                      hovermode="x unified",
                                      template="plotly_dark")
                    st.plotly_chart(fig, use_container_width=True)

                with tab2:
                    # OHLC Chart
                    ohlc_fig = go.Figure(data=[go.Candlestick(x=data.index,
                                                              open=data['Open'],
                                                              high=data['High'],
                                                              low=data['Low'],
                                                              close=data['Close'])])
                    ohlc_fig.update_layout(title=f"{company_name} OHLC Chart",
                                           xaxis_title="Date",
                                           yaxis_title="Price",
                                           template="plotly_dark",
                                           xaxis_rangeslider_visible=False)
                    st.plotly_chart(ohlc_fig, use_container_width=True)

                with tab3:
                    # Technical Indicators
                    data['SMA_20'] = SMAIndicator(close=data['Close'], window=20).sma_indicator()
                    data['EMA_20'] = EMAIndicator(close=data['Close'], window=20).ema_indicator()
                    bb = BollingerBands(close=data['Close'], window=20, window_dev=2)
                    data['BB_High'] = bb.bollinger_hband()
                    data['BB_Low'] = bb.bollinger_lband()
                    data['RSI'] = RSIIndicator(close=data['Close']).rsi()

                    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                                        vertical_spacing=0.03,
                                        subplot_titles=("Price and Indicators", "RSI"),
                                        row_heights=[0.7, 0.3])

                    fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='Close'), row=1, col=1)
                    fig.add_trace(go.Scatter(x=data.index, y=data['SMA_20'], mode='lines', name='SMA 20'), row=1, col=1)
                    fig.add_trace(go.Scatter(x=data.index, y=data['EMA_20'], mode='lines', name='EMA 20'), row=1, col=1)
                    fig.add_trace(go.Scatter(x=data.index, y=data['BB_High'], mode='lines', name='BB High'), row=1, col=1)
                    fig.add_trace(go.Scatter(x=data.index, y=data['BB_Low'], mode='lines', name='BB Low'), row=1, col=1)

                    fig.add_trace(go.Scatter(x=data.index, y=data['RSI'], mode='lines', name='RSI'), row=2, col=1)
                    fig.add_hline(y=70, line_dash="dash", line_color="red", row=2, col=1)
                    fig.add_hline(y=30, line_dash="dash", line_color="green", row=2, col=1)

                    fig.update_layout(height=800, title_text=f"{company_name} Technical Indicators",
                                      hovermode="x unified", template="plotly_dark")
                    fig.update_xaxes(rangeslider_visible=False, row=2, col=1)
                    fig.update_yaxes(title_text="Price", row=1, col=1)
                    fig.update_yaxes(title_text="RSI", row=2, col=1)

                    st.plotly_chart(fig, use_container_width=True)

                with tab4:
                    # Volume and Turnover
                    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                                        vertical_spacing=0.03,
                                        subplot_titles=("Volume", "Turnover (if available)"),
                                        row_heights=[0.5, 0.5])

                    fig.add_trace(go.Bar(x=data.index, y=data['Volume'], name='Volume'), row=1, col=1)

                    if 'Turnover' in data.columns:
                        fig.add_trace(go.Bar(x=data.index, y=data['Turnover'], name='Turnover'), row=2, col=1)
                    else:
                        fig.add_annotation(text="Turnover data not available", xref="paper", yref="paper", x=0.5, y=0.25, showarrow=False)

                    fig.update_layout(height=600, title_text=f"{company_name} Volume and Turnover",
                                      hovermode="x unified", template="plotly_dark")
                    fig.update_xaxes(rangeslider_visible=False, row=2, col=1)
                    fig.update_yaxes(title_text="Volume", row=1, col=1)
                    fig.update_yaxes(title_text="Turnover", row=2, col=1)

                    st.plotly_chart(fig, use_container_width=True)

                with tab5:
                    # Display key statistics
                    st.subheader("Key Statistics")
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.metric("Current Price", f"${data['Close'].iloc[-1]:.2f}")
                        # High/low are computed over the selected period, not a fixed 52 weeks
                        st.metric("Period High", f"${data['High'].max():.2f}")
                    with col2:
                        st.metric("Volume", f"{data['Volume'].iloc[-1]:,}")
                        st.metric("Period Low", f"${data['Low'].min():.2f}")
                    with col3:
                        returns = (data['Close'].pct_change() * 100).dropna()
                        st.metric("Avg Daily Return", f"{returns.mean():.2f}%")
                        st.metric("Return Volatility", f"{returns.std():.2f}%")

                    # Correlation Heatmap
                    correlation = data[['Open', 'High', 'Low', 'Close', 'Volume']].corr()
                    heatmap_fig = px.imshow(correlation, text_auto=True, aspect="auto", color_continuous_scale='Viridis')
                    heatmap_fig.update_layout(title="Correlation Heatmap", template="plotly_dark")
                    st.plotly_chart(heatmap_fig, use_container_width=True)

                # Display news
                st.subheader("Latest News")
                news = fetch_news(company_name)
                for item in news:
                    st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")

            else:
                st.error("Failed to fetch data. Please try again.")

if __name__ == "__main__":
    main()