Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,35 +8,97 @@ from datetime import datetime, timedelta
|
|
8 |
from PIL import Image
|
9 |
from plotly import graph_objs as go
|
10 |
from datetime import date
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
st.set_page_config(layout='wide', initial_sidebar_state='expanded')
|
13 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
14 |
st.title('ML Wall Street')
|
15 |
st.image('images/img.png')
|
16 |
|
17 |
-
# @st.cache_data
|
18 |
# Функция для получения данных о ценах акций
|
|
|
19 |
def get_stock_data():
|
20 |
dow_tickers = ['UNH', 'MSFT', 'GS', 'HD', 'AMGN', 'MCD', 'CAT', 'CRM', 'V', 'BA', 'HON', 'TRV', 'AAPL', 'AXP', 'JPM', 'IBM', 'JNJ', 'WMT', 'PG', 'CVX', 'MRK', 'MMM', 'NKE', 'DIS', 'KO', 'DOW', 'CSCO', 'INTC', 'VZ', 'WBA']
|
21 |
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')
|
22 |
end_date = datetime.now().strftime('%Y-%m-%d')
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
return dow_data
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
data = get_stock_data()
|
27 |
latest_date = data.index[-1].strftime('%Y-%m-%d')
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
# latest_date = data.index[-1].strftime('%Y-%m-%d')
|
36 |
-
# data = data.loc[latest_date, 'Close'].reset_index()
|
37 |
-
# data.columns = ['Ticker', 'Close']
|
38 |
-
# data['Close'] = data['Close'].round(2)
|
39 |
-
# st.success("Данные успешно обновлены!")
|
40 |
|
41 |
st.markdown(f"<h3 style='text-align: center;'>Цены актуальны на последнюю дату закрытия торгов {latest_date}</h3>", unsafe_allow_html=True)
|
42 |
|
@@ -71,14 +133,6 @@ with col2:
|
|
71 |
st.dataframe(data, height=1088, column_config={"Logo": image_column, "Ticker":ticker_column, 'Close':price_column})
|
72 |
|
73 |
with col1:
|
74 |
-
START = "1920-01-01"
|
75 |
-
TODAY = date.today().strftime("%Y-%m-%d")
|
76 |
-
# @st.cache_data
|
77 |
-
def load_data(ticker):
|
78 |
-
data = yf.download(ticker, START, TODAY)
|
79 |
-
data.reset_index(inplace=True)
|
80 |
-
return data
|
81 |
-
|
82 |
def plot_raw_data(data, text):
|
83 |
fig = go.Figure()
|
84 |
fig.add_trace(go.Scatter(x=data['Date'], y=data['Close'], name="Цена закрытия"))
|
|
|
8 |
from PIL import Image
|
9 |
from plotly import graph_objs as go
|
10 |
from datetime import date
|
11 |
+
from model.lstm_model import BiLSTM
|
12 |
+
|
13 |
+
# Загрузка весов модели (выполняется только при первом запуске)
|
14 |
+
@st.cache(allow_output_mutation=True)
|
15 |
+
def load_model_weights():
|
16 |
+
model = BiLSTM(input_size, hidden_size, num_layers, output_size)
|
17 |
+
model.load_state_dict(torch.load('model/model_weights.pth'))
|
18 |
+
model.eval()
|
19 |
+
return model
|
20 |
+
|
21 |
|
22 |
st.set_page_config(layout='wide', initial_sidebar_state='expanded')
|
23 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
24 |
st.title('ML Wall Street')
|
25 |
st.image('images/img.png')
|
26 |
|
|
|
27 |
# Функция для получения данных о ценах акций
|
28 |
+
@st.cache(allow_output_mutation=True)
|
29 |
def get_stock_data():
|
30 |
dow_tickers = ['UNH', 'MSFT', 'GS', 'HD', 'AMGN', 'MCD', 'CAT', 'CRM', 'V', 'BA', 'HON', 'TRV', 'AAPL', 'AXP', 'JPM', 'IBM', 'JNJ', 'WMT', 'PG', 'CVX', 'MRK', 'MMM', 'NKE', 'DIS', 'KO', 'DOW', 'CSCO', 'INTC', 'VZ', 'WBA']
|
31 |
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')
|
32 |
end_date = datetime.now().strftime('%Y-%m-%d')
|
33 |
+
|
34 |
+
# Проверка, прошло ли более 24 часов с последнего обновления данных
|
35 |
+
if 'last_stock_update' not in st.session_state or (datetime.now() - st.session_state.last_stock_update).total_seconds() > 43200:
|
36 |
+
dow_data = yf.download(dow_tickers, start=start_date, end=end_date)
|
37 |
+
|
38 |
+
# Сохранение данных в кэш и в сессионном состоянии
|
39 |
+
st.cache(f"{start_date}_{end_date}")(dow_data)
|
40 |
+
st.session_state.stock_data = dow_data
|
41 |
+
st.session_state.last_stock_update = datetime.now()
|
42 |
+
else:
|
43 |
+
# Если данные уже в сессионном состоянии, возвращаем их
|
44 |
+
dow_data = st.session_state.stock_data
|
45 |
+
|
46 |
return dow_data
|
47 |
|
48 |
+
# Функция для получения данных по индексу
|
49 |
+
@st.cache(allow_output_mutation=True)
|
50 |
+
def load_data(index_symbol):
|
51 |
+
start_date = "2021-01-01"
|
52 |
+
end_date = date.today().strftime("%Y-%m-%d")
|
53 |
+
|
54 |
+
# Проверка, прошло ли более 24 часов с последнего обновления данных
|
55 |
+
last_update_key = f'last_{index_symbol.lower()}_update'
|
56 |
+
data_key = f'{index_symbol.lower()}_data'
|
57 |
+
|
58 |
+
if last_update_key not in st.session_state or (datetime.now() - st.session_state[last_update_key]).total_seconds() > 86400:
|
59 |
+
df = yf.download(index_symbol, start=start_date, end=end_date)
|
60 |
+
df.reset_index(inplace=True)
|
61 |
+
|
62 |
+
# Сохранение данных в кэш и в сессионном состоянии
|
63 |
+
st.cache(f"{index_symbol.lower()}_{start_date}_{end_date}")(df)
|
64 |
+
st.session_state[data_key] = df
|
65 |
+
st.session_state[last_update_key] = datetime.now()
|
66 |
+
else:
|
67 |
+
# Если данные уже в сессионном состоянии, возвращаем их
|
68 |
+
df = st.session_state[data_key]
|
69 |
+
|
70 |
+
return df
|
71 |
+
|
72 |
+
# Проверка, прошло ли более 12 часов с последнего обновления данных для индекса
|
73 |
+
if last_update_key not in st.session_state or (datetime.now() - st.session_state[last_update_key]).total_seconds() > 43200:
|
74 |
+
df = yf.download(index_symbol, start=start_date, end=end_date)
|
75 |
+
df.reset_index(inplace=True)
|
76 |
+
|
77 |
+
# Сохранение данных в кэш и в сессионном состоянии
|
78 |
+
st.cache(f"{index_symbol.lower()}_{start_date}_{end_date}")(df)
|
79 |
+
st.session_state[data_key] = df
|
80 |
+
st.session_state[last_update_key] = datetime.now()
|
81 |
+
else:
|
82 |
+
# Если данные уже в сессионном состоянии, возвращаем их
|
83 |
+
df = st.session_state[data_key]
|
84 |
+
|
85 |
+
# Пример использования для разных индексов
|
86 |
+
bitcoin_data = load_data('BTC-USD')
|
87 |
+
sse_data = load_data('000001.SS')
|
88 |
+
moex_data = load_data('IMOEX.ME')
|
89 |
+
dji_data = load_data('^DJI')
|
90 |
+
sp500_data = load_data('^GSPC')
|
91 |
+
|
92 |
+
# Получение данных о ценах акций
|
93 |
data = get_stock_data()
|
94 |
latest_date = data.index[-1].strftime('%Y-%m-%d')
|
95 |
+
|
96 |
+
model_weights = load_model_weights()
|
97 |
+
|
98 |
+
# Сохранение весов модели в ��ессионном состоянии
|
99 |
+
if 'model_weights' not in st.session_state:
|
100 |
+
st.session_state.model_weights = model_weights
|
101 |
+
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
st.markdown(f"<h3 style='text-align: center;'>Цены актуальны на последнюю дату закрытия торгов {latest_date}</h3>", unsafe_allow_html=True)
|
104 |
|
|
|
133 |
st.dataframe(data, height=1088, column_config={"Logo": image_column, "Ticker":ticker_column, 'Close':price_column})
|
134 |
|
135 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
def plot_raw_data(data, text):
|
137 |
fig = go.Figure()
|
138 |
fig.add_trace(go.Scatter(x=data['Date'], y=data['Close'], name="Цена закрытия"))
|