Update app.py
Browse files
app.py
CHANGED
@@ -36,7 +36,7 @@ COMPANIES = [
|
|
36 |
]
|
37 |
|
38 |
class StockPredictor:
|
39 |
-
def __init__(self, data, model_type):
|
40 |
self.data = data
|
41 |
self.model_type = model_type
|
42 |
self.model = None
|
@@ -349,7 +349,7 @@ def main():
|
|
349 |
predict_stock_prices()
|
350 |
|
351 |
def test_model():
|
352 |
-
st.header("Test
|
353 |
|
354 |
col1, col2 = st.columns(2)
|
355 |
|
@@ -357,6 +357,9 @@ def test_model():
|
|
357 |
company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
|
358 |
test_split = st.slider("Test Data Split", 0.1, 0.5, 0.2, 0.05)
|
359 |
|
|
|
|
|
|
|
360 |
if st.button("Train and Test Model"):
|
361 |
with st.spinner("Fetching data and training model..."):
|
362 |
company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
|
@@ -375,7 +378,7 @@ def test_model():
|
|
375 |
train_data = data.iloc[:split_index]
|
376 |
test_data = data.iloc[split_index:]
|
377 |
|
378 |
-
predictor = StockPredictor(train_data) # Updated:
|
379 |
predictor.preprocess_data()
|
380 |
if predictor.train_model():
|
381 |
test_pred = predictor.predict(days=len(test_data))
|
@@ -392,7 +395,7 @@ def test_model():
|
|
392 |
plot = create_test_plot(predictor.data, test_data, test_pred, company_name)
|
393 |
st.plotly_chart(plot, use_container_width=True)
|
394 |
else:
|
395 |
-
st.error("Failed to train the
|
396 |
|
397 |
def predict_stock_prices():
|
398 |
st.header("Predict Stock Prices")
|
@@ -403,6 +406,9 @@ def predict_stock_prices():
|
|
403 |
company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
|
404 |
days_to_predict = st.slider("Days to Predict", 1, 365, 30)
|
405 |
|
|
|
|
|
|
|
406 |
if st.button("Predict Stock Prices"):
|
407 |
with st.spinner("Fetching data and making predictions..."):
|
408 |
company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
|
@@ -417,7 +423,7 @@ def predict_stock_prices():
|
|
417 |
|
418 |
st.markdown(get_table_download_link(data), unsafe_allow_html=True)
|
419 |
|
420 |
-
predictor = StockPredictor(data) # Updated:
|
421 |
predictor.preprocess_data()
|
422 |
if predictor.train_model():
|
423 |
predictions = predictor.predict(days=days_to_predict)
|
@@ -439,7 +445,7 @@ def predict_stock_prices():
|
|
439 |
for item in news:
|
440 |
st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")
|
441 |
else:
|
442 |
-
st.error("Failed to train the
|
443 |
|
444 |
def explore_data():
|
445 |
st.header("Explore Stock Data")
|
|
|
36 |
]
|
37 |
|
38 |
class StockPredictor:
|
39 |
+
def __init__(self, data, model_type='Prophet'):
|
40 |
self.data = data
|
41 |
self.model_type = model_type
|
42 |
self.model = None
|
|
|
349 |
predict_stock_prices()
|
350 |
|
351 |
def test_model():
|
352 |
+
st.header("Test Stock Prediction Model")
|
353 |
|
354 |
col1, col2 = st.columns(2)
|
355 |
|
|
|
357 |
company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
|
358 |
test_split = st.slider("Test Data Split", 0.1, 0.5, 0.2, 0.05)
|
359 |
|
360 |
+
with col2:
|
361 |
+
model_type = st.selectbox("Select Model Type", ['Prophet', 'LSTM', 'SARIMA', 'XGBoost', 'RandomForest'])
|
362 |
+
|
363 |
if st.button("Train and Test Model"):
|
364 |
with st.spinner("Fetching data and training model..."):
|
365 |
company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
|
|
|
378 |
train_data = data.iloc[:split_index]
|
379 |
test_data = data.iloc[split_index:]
|
380 |
|
381 |
+
predictor = StockPredictor(train_data, model_type) # Updated: added model_type argument
|
382 |
predictor.preprocess_data()
|
383 |
if predictor.train_model():
|
384 |
test_pred = predictor.predict(days=len(test_data))
|
|
|
395 |
plot = create_test_plot(predictor.data, test_data, test_pred, company_name)
|
396 |
st.plotly_chart(plot, use_container_width=True)
|
397 |
else:
|
398 |
+
st.error(f"Failed to train the {model_type} model. Please try a different dataset or model type.")
|
399 |
|
400 |
def predict_stock_prices():
|
401 |
st.header("Predict Stock Prices")
|
|
|
406 |
company = st.selectbox("Select Company", [company for company, _ in COMPANIES])
|
407 |
days_to_predict = st.slider("Days to Predict", 1, 365, 30)
|
408 |
|
409 |
+
with col2:
|
410 |
+
model_type = st.selectbox("Select Model Type", ['Prophet', 'LSTM', 'SARIMA', 'XGBoost', 'RandomForest'])
|
411 |
+
|
412 |
if st.button("Predict Stock Prices"):
|
413 |
with st.spinner("Fetching data and making predictions..."):
|
414 |
company_name, ticker = next((name, symbol) for name, symbol in COMPANIES if name == company)
|
|
|
423 |
|
424 |
st.markdown(get_table_download_link(data), unsafe_allow_html=True)
|
425 |
|
426 |
+
predictor = StockPredictor(data, model_type) # Updated: added model_type argument
|
427 |
predictor.preprocess_data()
|
428 |
if predictor.train_model():
|
429 |
predictions = predictor.predict(days=days_to_predict)
|
|
|
445 |
for item in news:
|
446 |
st.markdown(f"[{item['title']}]({item['link']}) ({item['pubDate']})")
|
447 |
else:
|
448 |
+
st.error(f"Failed to train the {model_type} model. Please try a different dataset or model type.")
|
449 |
|
450 |
def explore_data():
|
451 |
st.header("Explore Stock Data")
|