Spaces:
Sleeping
Sleeping
import streamlit as st | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from datetime import date | |
from sklearn.model_selection import train_test_split | |
from prophet import Prophet | |
from prophet.plot import plot_plotly | |
from plotly import graph_objs as go | |
from prophet.make_holidays import make_holidays_df | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
st.set_page_config(layout='wide', initial_sidebar_state='expanded') | |
st.set_option('deprecation.showPyplotGlobalUse', False) | |
st.title('ML Wall Street') | |
st.image('images/img.png') | |
START = "2021-01-01" | |
TODAY = date.today().strftime("%Y-%m-%d") | |
period = st.slider('Количество дней прогноза:', 1, 14, 14) | |
# @st.cache_data | |
def load_data(): | |
dji = yf.download('^DJI', START, TODAY) | |
dji.reset_index(inplace=True) | |
data_500 = yf.download('^GSPC', START, TODAY) | |
data_500.reset_index(inplace=True) | |
sse = yf.download('000001.SS', START, TODAY) | |
sse.reset_index(inplace=True) | |
imoex = yf.download('IMOEX.ME', START, TODAY) | |
imoex.reset_index(inplace=True) | |
return dji, data_500, sse, imoex | |
def plot_forecast(model, forecast, text): | |
fig = plot_plotly(model, forecast) | |
fig.update_layout(title_text=text, xaxis_rangeslider_visible=True, xaxis_title='', yaxis_title='') | |
st.plotly_chart(fig, use_container_width=True, height=400, range_slider_visible=True) | |
dji, data_500, sse, imoex = load_data() | |
latest_date = dji['Date'].iloc[-1].strftime('%Y-%m-%d') | |
st.markdown(f"<h3 style='text-align: center;'>Цены актуальны на последнюю дату закрытия торгов {latest_date}</h3>", unsafe_allow_html=True) | |
# # Добавляем кнопку обновления данных | |
# if st.button("Обновить данные", type="primary"): | |
# dji = yf.download('^DJI', START, TODAY) | |
# dji.reset_index(inplace=True) | |
# data_500 = yf.download('^GSPC', START, TODAY) | |
# data_500.reset_index(inplace=True) | |
# sse = yf.download('000001.SS', START, TODAY) | |
# sse.reset_index(inplace=True) | |
# imoex = yf.download('IMOEX.ME', START, TODAY) | |
# imoex.reset_index(inplace=True) | |
# st.success("Данные успешно обновлены!") | |
# if st.button("или Обновить данные", type="primary"): | |
# dji, data_500, sse, imoex = load_data() | |
def evaluate_trend_first_day(predicted_values, actual_values): | |
# Разница между первым днем прогноза и последним днем тестовых данных | |
forecast_diff_first_last = predicted_values[0] - actual_values[-1] | |
# Оценка тренда на первый день: рост, падение, стабильность | |
if forecast_diff_first_last > 0: | |
return "Тенденция на первый день: Рост" | |
elif forecast_diff_first_last < 0: | |
return "Тенденция на первый день: Падение" | |
else: | |
return "Тенденция на первый день: Стабильность" | |
def evaluate_trend_period(predicted_values): | |
# Разница между первым и последним значением прогноза | |
forecast_diff = predicted_values[-1] - predicted_values[0] | |
# Оценка тренда на весь период прогноза: рост, падение, стабильность | |
print("Разница между первым и последним значением прогноза:", forecast_diff) | |
if forecast_diff > 0: | |
return "Тенденция на период прогноза: Рост" | |
elif forecast_diff < 0: | |
return "Тенденция на период прогноза: Падение" | |
else: | |
return "Тенденция на период прогноза: Стабильность" | |
def index(ind, country_name, text1, text2): | |
data = ind | |
data = data.rename(columns={'Date': 'ds', 'Adj Close': 'y'}) | |
# Разделение данных на обучающую и тестовую выборки | |
full_train_data3, full_test_data3 = train_test_split(data, test_size=period, shuffle=False) | |
# Удаляем временную зону из столбца ds | |
full_train_data3['ds'] = full_train_data3['ds'].dt.tz_localize(None) | |
full_test_data3['ds'] = full_test_data3['ds'].dt.tz_localize(None) | |
# Создаем модель Prophet | |
model = Prophet(interval_width=0.95) | |
# Добавляем стандартные праздничные дни для страны | |
model.add_country_holidays(country_name=country_name) | |
model.fit(full_train_data3) | |
future = model.make_future_dataframe(periods=full_test_data3.shape[0] + period, freq='D') | |
forecast3 = model.predict(future) | |
# Отрисовка графика | |
fig = go.Figure() | |
fig = plot_plotly(model, forecast3) | |
fig.add_trace(go.Scatter(x=full_test_data3['ds'], | |
y=full_test_data3['y'], | |
mode='markers', | |
marker=dict(color='orchid'), | |
name='Факт значения')) | |
fig.add_trace(go.Scatter(x=forecast3['ds'].iloc[-period:], y=forecast3['yhat'].iloc[-period:], mode='lines', name='Прогноз на +14 дней')) | |
fig.update_layout(title_text=text1, xaxis_rangeslider_visible=True, xaxis_title='', yaxis_title='') | |
fig.update_traces(showlegend=True) | |
st.plotly_chart(fig, use_container_width=True, range_slider_visible=True) | |
# Расчет метрик на тестовой выборке | |
actual_values_test = full_test_data3['y'].values | |
predicted_values_test = forecast3['yhat'].iloc[-period:].values | |
mape_test = np.mean(np.abs((actual_values_test - predicted_values_test) / actual_values_test)) * 100 | |
rmse_test = np.sqrt(mean_squared_error(actual_values_test, predicted_values_test)) | |
check = st.checkbox(text2) | |
if check: | |
col1, col2 = st.columns([1, 1]) | |
with col1: | |
st.write("**Информация.** \ | |
Прогноз сделан по нескольким дням прошлого и нескольким дням будущего периода.") | |
st.markdown("**Метрики для тестовой выборки:**") | |
st.write(f"MAPE: {mape_test:.2f}%") | |
st.write(f"RMSE: {rmse_test:.2f}") | |
# Оценка тренда на первый день | |
trend_evaluation_first_day = evaluate_trend_first_day(predicted_values_test, actual_values_test) | |
st.write(trend_evaluation_first_day) | |
# Оценка тренда на период прогноза | |
trend_evaluation_period = evaluate_trend_period(predicted_values_test) | |
st.write(trend_evaluation_period) | |
with col2: | |
forecast_results = pd.DataFrame({ | |
'Дата': forecast3['ds'].iloc[-period:].values, | |
'Прогноз': forecast3['yhat'].iloc[-period:].values.round(2) | |
}) | |
st.dataframe(forecast_results.set_index('Дата')) | |
text1_dji = f'График прогноза для {period} дней по индексу Dow Jones, USD 🇺🇸' | |
text2_dji = f"Результаты прогноза по Dow Jones Industrial Average" | |
index(dji, 'US', text1_dji, text2_dji) | |
text1_500 = f'График прогноза для {period} дней по индексу S&P 500, USD 🇺🇸' | |
text2_500 = f"Результаты прогноза по S&P 500" | |
index(data_500, 'US', text1_500, text2_500) | |
text1_sse = f'График прогноза для {period} дней по индексу SSE Composite, CNY 🇨🇳' | |
text2_sse = f"Результаты прогноза по SSE Composite Index" | |
index(sse, 'China', text1_sse, text2_sse) | |
text1_imoex = f'График прогноза для {period} дней по индексу MOEX Russia, RUB 🇷🇺' | |
text2_imoex = f"Результаты прогноза по MOEX Russia Index" | |
index(imoex, 'Russia', text1_imoex, text2_imoex) | |