Spaces:
Sleeping
Sleeping
Commit
·
6547b21
1
Parent(s):
1ddba8e
加入Prophet 雙選擇
Browse files
app.py
CHANGED
@@ -196,7 +196,7 @@ def update_stock(category, stock):
|
|
196 |
status_output: gr.update(value="")
|
197 |
}
|
198 |
|
199 |
-
def predict_stock(category, stock, stock_item, period, selected_features):
|
200 |
if not all([category, stock, stock_item]):
|
201 |
return gr.update(value=None), "請選擇產業類別、類股和股票"
|
202 |
|
@@ -208,55 +208,60 @@ def predict_stock(category, stock, stock_item, period, selected_features):
|
|
208 |
|
209 |
stock_items = get_stock_items(url)
|
210 |
stock_code = stock_items.get(stock_item, "")
|
211 |
-
|
212 |
if not stock_code:
|
213 |
return gr.update(value=None), "無法獲取股票代碼"
|
214 |
-
|
215 |
# 下載股票數據,根據用戶選擇的時間範圍
|
216 |
df = yf.download(stock_code, period=period)
|
217 |
if df.empty:
|
218 |
raise ValueError("無法獲取股票數據")
|
219 |
|
220 |
-
#
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
# 創建日期索引
|
235 |
-
dates = [datetime.now() + timedelta(days=i) for i in range(
|
236 |
date_labels = [d.strftime('%m/%d') for d in dates]
|
237 |
-
|
238 |
# 繪圖
|
239 |
fig, ax = plt.subplots(figsize=(14, 7))
|
240 |
-
|
241 |
-
labels = [f'預測{feature}' for feature in selected_features]
|
242 |
-
|
243 |
-
for i, (label, color) in enumerate(zip(labels, colors)):
|
244 |
-
ax.plot(date_labels, all_predictions[:, i], label=label,
|
245 |
-
marker='o', color=color, linewidth=2)
|
246 |
-
for j, value in enumerate(all_predictions[:, i]):
|
247 |
-
ax.annotate(f'{value:.2f}', (date_labels[j], value),
|
248 |
-
textcoords="offset points", xytext=(0,10),
|
249 |
-
ha='center', va='bottom')
|
250 |
-
|
251 |
ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
|
252 |
ax.set_xlabel('日期', labelpad=10)
|
253 |
ax.set_ylabel('股價', labelpad=10)
|
254 |
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
|
255 |
ax.grid(True, linestyle='--', alpha=0.7)
|
256 |
-
|
257 |
plt.tight_layout()
|
258 |
return gr.update(value=fig), "預測成功"
|
259 |
-
|
260 |
except Exception as e:
|
261 |
logging.error(f"預測過程發生錯誤: {str(e)}")
|
262 |
return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
|
@@ -296,28 +301,33 @@ with gr.Blocks() as demo:
|
|
296 |
label="選擇要用於預測的特徵",
|
297 |
value=['Open', 'Close']
|
298 |
)
|
|
|
|
|
|
|
|
|
|
|
299 |
predict_button = gr.Button("開始預測", variant="primary")
|
300 |
status_output = gr.Textbox(label="狀態", interactive=False)
|
301 |
|
302 |
with gr.Row():
|
303 |
stock_plot = gr.Plot(label="股價預測圖")
|
304 |
-
|
305 |
# 事件綁定
|
306 |
category_dropdown.change(
|
307 |
update_category,
|
308 |
inputs=[category_dropdown],
|
309 |
outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
|
310 |
)
|
311 |
-
|
312 |
stock_dropdown.change(
|
313 |
update_stock,
|
314 |
inputs=[category_dropdown, stock_dropdown],
|
315 |
outputs=[stock_item_dropdown, stock_plot, status_output]
|
316 |
)
|
317 |
-
|
318 |
predict_button.click(
|
319 |
predict_stock,
|
320 |
-
inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
|
321 |
outputs=[stock_plot, status_output]
|
322 |
)
|
323 |
|
|
|
196 |
status_output: gr.update(value="")
|
197 |
}
|
198 |
|
199 |
+
def predict_stock(category, stock, stock_item, period, selected_features, model_choice):
|
200 |
if not all([category, stock, stock_item]):
|
201 |
return gr.update(value=None), "請選擇產業類別、類股和股票"
|
202 |
|
|
|
208 |
|
209 |
stock_items = get_stock_items(url)
|
210 |
stock_code = stock_items.get(stock_item, "")
|
211 |
+
|
212 |
if not stock_code:
|
213 |
return gr.update(value=None), "無法獲取股票代碼"
|
214 |
+
|
215 |
# 下載股票數據,根據用戶選擇的時間範圍
|
216 |
df = yf.download(stock_code, period=period)
|
217 |
if df.empty:
|
218 |
raise ValueError("無法獲取股票數據")
|
219 |
|
220 |
+
# 根據模型選擇進行預測
|
221 |
+
if model_choice == "LSTM":
|
222 |
+
predictor = StockPredictor()
|
223 |
+
predictor.train(df, selected_features)
|
224 |
+
last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
|
225 |
+
predictions = predictor.predict(last_data[0], 5)
|
226 |
+
|
227 |
+
# 反轉預測結果
|
228 |
+
last_original = df[selected_features].iloc[-1].values
|
229 |
+
predictions_original = predictor.scaler.inverse_transform(
|
230 |
+
np.vstack([last_data, predictions])
|
231 |
+
)
|
232 |
+
all_predictions = np.vstack([last_original, predictions_original[1:]])
|
233 |
|
234 |
+
elif model_choice == "Prophet":
|
235 |
+
from prophet import Prophet
|
236 |
+
prophet_df = df.reset_index()[['Date', 'Close']]
|
237 |
+
prophet_df.rename(columns={'Date': 'ds', 'Close': 'y'}, inplace=True)
|
238 |
+
|
239 |
+
model = Prophet()
|
240 |
+
model.fit(prophet_df)
|
241 |
+
|
242 |
+
future = model.make_future_dataframe(periods=5)
|
243 |
+
forecast = model.predict(future)
|
244 |
+
all_predictions = forecast[['ds', 'yhat']].tail(6).values
|
245 |
+
|
246 |
+
else:
|
247 |
+
return gr.update(value=None), "未知的模型選擇"
|
248 |
+
|
249 |
# 創建日期索引
|
250 |
+
dates = [datetime.now() + timedelta(days=i) for i in range(len(all_predictions))]
|
251 |
date_labels = [d.strftime('%m/%d') for d in dates]
|
252 |
+
|
253 |
# 繪圖
|
254 |
fig, ax = plt.subplots(figsize=(14, 7))
|
255 |
+
ax.plot(date_labels, all_predictions, label="預測股價", marker='o', color='#FF9999', linewidth=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
|
257 |
ax.set_xlabel('日期', labelpad=10)
|
258 |
ax.set_ylabel('股價', labelpad=10)
|
259 |
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
|
260 |
ax.grid(True, linestyle='--', alpha=0.7)
|
261 |
+
|
262 |
plt.tight_layout()
|
263 |
return gr.update(value=fig), "預測成功"
|
264 |
+
|
265 |
except Exception as e:
|
266 |
logging.error(f"預測過程發生錯誤: {str(e)}")
|
267 |
return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
|
|
|
301 |
label="選擇要用於預測的特徵",
|
302 |
value=['Open', 'Close']
|
303 |
)
|
304 |
+
model_dropdown = gr.Dropdown(
|
305 |
+
choices=["LSTM", "Prophet"],
|
306 |
+
label="選擇預測模型",
|
307 |
+
value="LSTM"
|
308 |
+
)
|
309 |
predict_button = gr.Button("開始預測", variant="primary")
|
310 |
status_output = gr.Textbox(label="狀態", interactive=False)
|
311 |
|
312 |
with gr.Row():
|
313 |
stock_plot = gr.Plot(label="股價預測圖")
|
314 |
+
|
315 |
# 事件綁定
|
316 |
category_dropdown.change(
|
317 |
update_category,
|
318 |
inputs=[category_dropdown],
|
319 |
outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
|
320 |
)
|
321 |
+
|
322 |
stock_dropdown.change(
|
323 |
update_stock,
|
324 |
inputs=[category_dropdown, stock_dropdown],
|
325 |
outputs=[stock_item_dropdown, stock_plot, status_output]
|
326 |
)
|
327 |
+
|
328 |
predict_button.click(
|
329 |
predict_stock,
|
330 |
+
inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox, model_dropdown],
|
331 |
outputs=[stock_plot, status_output]
|
332 |
)
|
333 |
|