Solar-Iz commited on
Commit
88589b6
·
verified ·
1 Parent(s): a0b4f78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -22
app.py CHANGED
@@ -8,35 +8,97 @@ from datetime import datetime, timedelta
8
  from PIL import Image
9
  from plotly import graph_objs as go
10
  from datetime import date
 
 
 
 
 
 
 
 
 
 
11
 
12
  st.set_page_config(layout='wide', initial_sidebar_state='expanded')
13
  st.set_option('deprecation.showPyplotGlobalUse', False)
14
  st.title('ML Wall Street')
15
  st.image('images/img.png')
16
 
17
- # @st.cache_data
18
  # Функция для получения данных о ценах акций
 
19
  def get_stock_data():
20
  dow_tickers = ['UNH', 'MSFT', 'GS', 'HD', 'AMGN', 'MCD', 'CAT', 'CRM', 'V', 'BA', 'HON', 'TRV', 'AAPL', 'AXP', 'JPM', 'IBM', 'JNJ', 'WMT', 'PG', 'CVX', 'MRK', 'MMM', 'NKE', 'DIS', 'KO', 'DOW', 'CSCO', 'INTC', 'VZ', 'WBA']
21
  start_date = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')
22
  end_date = datetime.now().strftime('%Y-%m-%d')
23
- dow_data = yf.download(dow_tickers, start=start_date, end=end_date)
 
 
 
 
 
 
 
 
 
 
 
 
24
  return dow_data
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  data = get_stock_data()
27
  latest_date = data.index[-1].strftime('%Y-%m-%d')
28
- data = data.loc[latest_date, 'Close'].reset_index()
29
- data.columns = ['Ticker', 'Close']
30
- data['Close'] = data['Close'].round(2)
31
-
32
- # Добавляем кнопку обновления данных
33
- # if st.button("Обновить данные", type="primary"):
34
- # data = get_stock_data()
35
- # latest_date = data.index[-1].strftime('%Y-%m-%d')
36
- # data = data.loc[latest_date, 'Close'].reset_index()
37
- # data.columns = ['Ticker', 'Close']
38
- # data['Close'] = data['Close'].round(2)
39
- # st.success("Данные успешно обновлены!")
40
 
41
  st.markdown(f"<h3 style='text-align: center;'>Цены актуальны на последнюю дату закрытия торгов {latest_date}</h3>", unsafe_allow_html=True)
42
 
@@ -71,14 +133,6 @@ with col2:
71
  st.dataframe(data, height=1088, column_config={"Logo": image_column, "Ticker":ticker_column, 'Close':price_column})
72
 
73
  with col1:
74
- START = "1920-01-01"
75
- TODAY = date.today().strftime("%Y-%m-%d")
76
- # @st.cache_data
77
- def load_data(ticker):
78
- data = yf.download(ticker, START, TODAY)
79
- data.reset_index(inplace=True)
80
- return data
81
-
82
  def plot_raw_data(data, text):
83
  fig = go.Figure()
84
  fig.add_trace(go.Scatter(x=data['Date'], y=data['Close'], name="Цена закрытия"))
 
8
  from PIL import Image
9
  from plotly import graph_objs as go
10
  from datetime import date
11
+ from model.lstm_model import BiLSTM
12
+
13
+ # Загрузка весов модели (выполняется только при первом запуске)
14
+ @st.cache(allow_output_mutation=True)
15
+ def load_model_weights():
16
+ model = BiLSTM(input_size, hidden_size, num_layers, output_size)
17
+ model.load_state_dict(torch.load('model/model_weights.pth'))
18
+ model.eval()
19
+ return model
20
+
21
 
22
  st.set_page_config(layout='wide', initial_sidebar_state='expanded')
23
  st.set_option('deprecation.showPyplotGlobalUse', False)
24
  st.title('ML Wall Street')
25
  st.image('images/img.png')
26
 
 
27
  # Функция для получения данных о ценах акций
28
+ @st.cache(allow_output_mutation=True)
29
  def get_stock_data():
30
  dow_tickers = ['UNH', 'MSFT', 'GS', 'HD', 'AMGN', 'MCD', 'CAT', 'CRM', 'V', 'BA', 'HON', 'TRV', 'AAPL', 'AXP', 'JPM', 'IBM', 'JNJ', 'WMT', 'PG', 'CVX', 'MRK', 'MMM', 'NKE', 'DIS', 'KO', 'DOW', 'CSCO', 'INTC', 'VZ', 'WBA']
31
  start_date = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')
32
  end_date = datetime.now().strftime('%Y-%m-%d')
33
+
34
+ # Проверка, прошло ли более 24 часов с последнего обновления данных
35
+ if 'last_stock_update' not in st.session_state or (datetime.now() - st.session_state.last_stock_update).total_seconds() > 43200:
36
+ dow_data = yf.download(dow_tickers, start=start_date, end=end_date)
37
+
38
+ # Сохранение данных в кэш и в сессионном состоянии
39
+ st.cache(f"{start_date}_{end_date}")(dow_data)
40
+ st.session_state.stock_data = dow_data
41
+ st.session_state.last_stock_update = datetime.now()
42
+ else:
43
+ # Если данные уже в сессионном состоянии, возвращаем их
44
+ dow_data = st.session_state.stock_data
45
+
46
  return dow_data
47
 
48
+ # Функция для получения данных по индексу
49
+ @st.cache(allow_output_mutation=True)
50
+ def load_data(index_symbol):
51
+ start_date = "2021-01-01"
52
+ end_date = date.today().strftime("%Y-%m-%d")
53
+
54
+ # Проверка, прошло ли более 24 часов с последнего обновления данных
55
+ last_update_key = f'last_{index_symbol.lower()}_update'
56
+ data_key = f'{index_symbol.lower()}_data'
57
+
58
+ if last_update_key not in st.session_state or (datetime.now() - st.session_state[last_update_key]).total_seconds() > 86400:
59
+ df = yf.download(index_symbol, start=start_date, end=end_date)
60
+ df.reset_index(inplace=True)
61
+
62
+ # Сохранение данных в кэш и в сессионном состоянии
63
+ st.cache(f"{index_symbol.lower()}_{start_date}_{end_date}")(df)
64
+ st.session_state[data_key] = df
65
+ st.session_state[last_update_key] = datetime.now()
66
+ else:
67
+ # Если данные уже в сессионном состоянии, возвращаем их
68
+ df = st.session_state[data_key]
69
+
70
+ return df
71
+
72
+ # Проверка, прошло ли более 12 часов с последнего обновления данных для индекса
73
+ if last_update_key not in st.session_state or (datetime.now() - st.session_state[last_update_key]).total_seconds() > 43200:
74
+ df = yf.download(index_symbol, start=start_date, end=end_date)
75
+ df.reset_index(inplace=True)
76
+
77
+ # Сохранение данных в кэш и в сессионном состоянии
78
+ st.cache(f"{index_symbol.lower()}_{start_date}_{end_date}")(df)
79
+ st.session_state[data_key] = df
80
+ st.session_state[last_update_key] = datetime.now()
81
+ else:
82
+ # Если данные уже в сессионном состоянии, возвращаем их
83
+ df = st.session_state[data_key]
84
+
85
+ # Пример использования для разных индексов
86
+ bitcoin_data = load_data('BTC-USD')
87
+ sse_data = load_data('000001.SS')
88
+ moex_data = load_data('IMOEX.ME')
89
+ dji_data = load_data('^DJI')
90
+ sp500_data = load_data('^GSPC')
91
+
92
+ # Получение данных о ценах акций
93
  data = get_stock_data()
94
  latest_date = data.index[-1].strftime('%Y-%m-%d')
95
+
96
+ model_weights = load_model_weights()
97
+
98
+ # Сохранение весов модели в ��ессионном состоянии
99
+ if 'model_weights' not in st.session_state:
100
+ st.session_state.model_weights = model_weights
101
+
 
 
 
 
 
102
 
103
  st.markdown(f"<h3 style='text-align: center;'>Цены актуальны на последнюю дату закрытия торгов {latest_date}</h3>", unsafe_allow_html=True)
104
 
 
133
  st.dataframe(data, height=1088, column_config={"Logo": image_column, "Ticker":ticker_column, 'Close':price_column})
134
 
135
  with col1:
 
 
 
 
 
 
 
 
136
  def plot_raw_data(data, text):
137
  fig = go.Figure()
138
  fig.add_trace(go.Scatter(x=data['Date'], y=data['Close'], name="Цена закрытия"))