Johan713 commited on
Commit
38753df
·
verified ·
1 Parent(s): ff6b6c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -36,7 +36,7 @@ COMPANIES = [
36
  ]
37
 
38
  class StockPredictor:
39
- def __init__(self, data, model_type):
40
  self.data = data
41
  self.model_type = model_type
42
  self.model = None
@@ -349,7 +349,7 @@ def main():
349
  predict_stock_prices()
350
 
351
  def test_model():
352
- st.header("Test Prophet Model")
353
 
354
  col1, col2 = st.columns(2)
355
 
@@ -357,6 +357,9 @@ def test_model():
357
  company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
358
  test_split = st.slider("Test Data Split", 0.1, 0.5, 0.2, 0.05)
359
 
 
 
 
360
  if st.button("Train and Test Model"):
361
  with st.spinner("Fetching data and training model..."):
362
  company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
@@ -375,7 +378,7 @@ def test_model():
375
  train_data = data.iloc[:split_index]
376
  test_data = data.iloc[split_index:]
377
 
378
- predictor = StockPredictor(train_data) # Updated: removed model_type argument
379
  predictor.preprocess_data()
380
  if predictor.train_model():
381
  test_pred = predictor.predict(days=len(test_data))
@@ -392,7 +395,7 @@ def test_model():
392
  plot = create_test_plot(predictor.data, test_data, test_pred, company_name)
393
  st.plotly_chart(plot, use_container_width=True)
394
  else:
395
- st.error("Failed to train the Prophet model. Please try a different dataset.")
396
 
397
  def predict_stock_prices():
398
  st.header("Predict Stock Prices")
@@ -403,6 +406,9 @@ def predict_stock_prices():
403
  company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
404
  days_to_predict = st.slider("Days to Predict", 1, 365, 30)
405
 
 
 
 
406
  if st.button("Predict Stock Prices"):
407
  with st.spinner("Fetching data and making predictions..."):
408
  company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
@@ -417,7 +423,7 @@ def predict_stock_prices():
417
 
418
  st.markdown(get_table_download_link(data), unsafe_allow_html=True)
419
 
420
- predictor = StockPredictor(data) # Updated: removed model_type argument
421
  predictor.preprocess_data()
422
  if predictor.train_model():
423
  predictions = predictor.predict(days=days_to_predict)
@@ -439,7 +445,7 @@ def predict_stock_prices():
439
  for item in news:
440
  st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")
441
  else:
442
- st.error("Failed to train the Prophet model. Please try a different dataset.")
443
 
444
  def explore_data():
445
  st.header("Explore Stock Data")
 
36
  ]
37
 
38
  class StockPredictor:
39
+ def __init__(self, data, model_type='Prophet'):
40
  self.data = data
41
  self.model_type = model_type
42
  self.model = None
 
349
  predict_stock_prices()
350
 
351
  def test_model():
352
+ st.header("Test Stock Prediction Model")
353
 
354
  col1, col2 = st.columns(2)
355
 
 
357
  company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
358
  test_split = st.slider("Test Data Split", 0.1, 0.5, 0.2, 0.05)
359
 
360
+ with col2:
361
+ model_type = st.selectbox("Select Model Type", ['Prophet', 'LSTM', 'SARIMA', 'XGBoost', 'RandomForest'])
362
+
363
  if st.button("Train and Test Model"):
364
  with st.spinner("Fetching data and training model..."):
365
  company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
 
378
  train_data = data.iloc[:split_index]
379
  test_data = data.iloc[split_index:]
380
 
381
+ predictor = StockPredictor(train_data, model_type) # Updated: added model_type argument
382
  predictor.preprocess_data()
383
  if predictor.train_model():
384
  test_pred = predictor.predict(days=len(test_data))
 
395
  plot = create_test_plot(predictor.data, test_data, test_pred, company_name)
396
  st.plotly_chart(plot, use_container_width=True)
397
  else:
398
+ st.error(f"Failed to train the {model_type} model. Please try a different dataset or model type.")
399
 
400
  def predict_stock_prices():
401
  st.header("Predict Stock Prices")
 
406
  company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
407
  days_to_predict = st.slider("Days to Predict", 1, 365, 30)
408
 
409
+ with col2:
410
+ model_type = st.selectbox("Select Model Type", ['Prophet', 'LSTM', 'SARIMA', 'XGBoost', 'RandomForest'])
411
+
412
  if st.button("Predict Stock Prices"):
413
  with st.spinner("Fetching data and making predictions..."):
414
  company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
 
423
 
424
  st.markdown(get_table_download_link(data), unsafe_allow_html=True)
425
 
426
+ predictor = StockPredictor(data, model_type) # Updated: added model_type argument
427
  predictor.preprocess_data()
428
  if predictor.train_model():
429
  predictions = predictor.predict(days=days_to_predict)
 
445
  for item in news:
446
  st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")
447
  else:
448
+ st.error(f"Failed to train the {model_type} model. Please try a different dataset or model type.")
449
 
450
  def explore_data():
451
  st.header("Explore Stock Data")