MacDash commited on
Commit
c63df78
·
verified ·
1 Parent(s): 760a820

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +265 -4
app.py CHANGED
@@ -85,6 +85,7 @@ def fetch_okx_symbols():
85
  symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
86
  return ["BTC-USDT"] + symbols if symbols else ["BTC-USDT"]
87
  except Exception as e:
 
88
  return ["BTC-USDT"]
89
 
90
  # Fetch historical candle data from OKX API
@@ -94,8 +95,16 @@ def fetch_okx_candles(symbol, timeframe="1H", total=2000):
94
 
95
  for _ in range(calls_needed):
96
  params = {"instId": symbol, "bar": timeframe, "limit": 300}
97
- resp = requests.get(OKX_CANDLE_ENDPOINT, params=params)
98
- data = resp.json().get("data", [])
 
 
 
 
 
 
 
 
99
 
100
  if not data:
101
  break
@@ -115,5 +124,257 @@ def fetch_okx_candles(symbol, timeframe="1H", total=2000):
115
 
116
  df_all = pd.concat(all_data)
117
 
118
- # Convert timestamps to datetime and calculate indicators
119
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
86
  return ["BTC-USDT"] + symbols if symbols else ["BTC-USDT"]
87
  except Exception as e:
88
+ print(f"Error fetching symbols: {e}")
89
  return ["BTC-USDT"]
90
 
91
  # Fetch historical candle data from OKX API
 
95
 
96
  for _ in range(calls_needed):
97
  params = {"instId": symbol, "bar": timeframe, "limit": 300}
98
+ try:
99
+ resp = requests.get(OKX_CANDLE_ENDPOINT, params=params)
100
+ resp.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
101
+ data = resp.json().get("data", [])
102
+ except requests.exceptions.RequestException as e:
103
+ print(f"Error fetching candles: {e}")
104
+ return pd.DataFrame()
105
+ except (ValueError, KeyError) as e:
106
+ print(f"Error parsing candle data: {e}")
107
+ return pd.DataFrame()
108
 
109
  if not data:
110
  break
 
124
 
125
  df_all = pd.concat(all_data)
126
 
127
+ # Convert timestamps to datetime and calculate indicators
128
+ df_all["timestamp"] = pd.to_datetime(df_all["timestamp"], unit="ms")
129
+ numeric_cols = ["open", "high", "low", "close"]
130
+ df_all[numeric_cols] = df_all[numeric_cols].astype(float)
131
+ df_all = calculate_technical_indicators(df_all)
132
+
133
+ return df_all
134
+
135
+ # Prepare data for Prophet forecasting
136
+ def prepare_data_for_prophet(df):
137
+ if df.empty:
138
+ return pd.DataFrame(columns=["ds", "y"])
139
+ df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
140
+ return df_prophet[["ds", "y"]]
141
+
142
+ # Perform forecasting using Prophet
143
+ def prophet_forecast(df_prophet, periods=10, freq="h", daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=False, seasonality_mode="additive", changepoint_prior_scale=0.05):
144
+ if df_prophet.empty:
145
+ return pd.DataFrame(), "No data for Prophet."
146
+
147
+ try:
148
+ model = Prophet(
149
+ daily_seasonality=daily_seasonality,
150
+ weekly_seasonality=weekly_seasonality,
151
+ yearly_seasonality=yearly_seasonality,
152
+ seasonality_mode=seasonality_mode,
153
+ changepoint_prior_scale=changepoint_prior_scale,
154
+ )
155
+ model.fit(df_prophet)
156
+ future = model.make_future_dataframe(periods=periods, freq=freq)
157
+ forecast = model.predict(future)
158
+ return forecast, ""
159
+ except Exception as e:
160
+ return pd.DataFrame(), f"Forecast error: {e}"
161
+
162
+ # Wrapper function for forecasting
163
+ def prophet_wrapper(df_prophet, forecast_steps, freq, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
164
+ if len(df_prophet) < 10:
165
+ return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
166
+
167
+ full_forecast, err = prophet_forecast(
168
+ df_prophet,
169
+ periods=forecast_steps,
170
+ freq=freq,
171
+ daily_seasonality=daily_seasonality,
172
+ weekly_seasonality=weekly_seasonality,
173
+ yearly_seasonality=yearly_seasonality,
174
+ seasonality_mode=seasonality_mode,
175
+ changepoint_prior_scale=changepoint_prior_scale,
176
+ )
177
+ if err:
178
+ return pd.DataFrame(), err
179
+
180
+ future_only = full_forecast.loc[len(df_prophet):, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
181
+ return future_only, ""
182
+
183
+ # Create forecast plot
184
+ def create_forecast_plot(forecast_df):
185
+ if forecast_df.empty:
186
+ return go.Figure()
187
+
188
+ fig = go.Figure()
189
+ fig.add_trace(go.Scatter(
190
+ x=forecast_df["ds"],
191
+ y=forecast_df["yhat"],
192
+ mode="lines",
193
+ name="Forecast",
194
+ line=dict(color="blue", width=2)
195
+ ))
196
+
197
+ fig.add_trace(go.Scatter(
198
+ x=forecast_df["ds"],
199
+ y=forecast_df["yhat_lower"],
200
+ fill=None,
201
+ mode="lines",
202
+ line=dict(width=0),
203
+ showlegend=True,
204
+ name="Lower Bound"
205
+ ))
206
+
207
+ fig.add_trace(go.Scatter(
208
+ x=forecast_df["ds"],
209
+ y=forecast_df["yhat_upper"],
210
+ fill="tonexty",
211
+ mode="lines",
212
+ line=dict(width=0),
213
+ name="Upper Bound"
214
+ ))
215
+
216
+ fig.update_layout(
217
+ title="Price Forecast",
218
+ xaxis_title="Time",
219
+ yaxis_title="Price",
220
+ hovermode="x unified",
221
+ template="plotly_white",
222
+ )
223
+ return fig
224
+
225
+ # Function to display forecast and technical analysis charts
226
+ def display_forecast(symbol, timeframe, forecast_steps, total_candles, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
227
+ df_raw, forecast_df, error = predict(
228
+ symbol=symbol,
229
+ timeframe=timeframe,
230
+ forecast_steps=forecast_steps,
231
+ total_candles=total_candles,
232
+ daily_seasonality=daily_seasonality,
233
+ weekly_seasonality=weekly_seasonality,
234
+ yearly_seasonality=yearly_seasonality,
235
+ seasonality_mode=seasonality_mode,
236
+ changepoint_prior_scale=changepoint_prior_scale
237
+ )
238
+
239
+ if error:
240
+ return None, None, None, None, pd.DataFrame() # Return empty dataframe for forecast_df
241
+
242
+ forecast_plot = create_forecast_plot(forecast_df)
243
+ tech_plot, rsi_plot, macd_plot = create_technical_charts(df_raw)
244
+
245
+ # Prepare forecast data for the Dataframe output
246
+ forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
247
+ forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
248
+
249
+ return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display
250
+
251
+ # Main prediction function
252
+ def predict(symbol, timeframe, forecast_steps, total_candles, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
253
+ okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
254
+ df_raw = fetch_okx_candles(symbol=symbol, timeframe=okx_bar, total=total_candles)
255
+
256
+ if df_raw.empty:
257
+ return pd.DataFrame(), pd.DataFrame(), "No data fetched."
258
+
259
+ df_prophet = prepare_data_for_prophet(df_raw)
260
+ freq = "h" if "h" in timeframe.lower() else "d"
261
+
262
+ future_df, err2 = prophet_wrapper(
263
+ df_prophet,
264
+ forecast_steps,
265
+ freq,
266
+ daily_seasonality,
267
+ weekly_seasonality,
268
+ yearly_seasonality,
269
+ seasonality_mode,
270
+ changepoint_prior_scale,
271
+ )
272
+
273
+ if err2:
274
+ return pd.DataFrame(), pd.DataFrame(), err2
275
+
276
+ return df_raw, future_df, ""
277
+
278
+
279
+ # Main Gradio app setup
280
+ def main():
281
+ symbols = fetch_okx_symbols()
282
+
283
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
284
+ # Header
285
+ with gr.Row():
286
+ gr.Markdown("# CryptoVision")
287
+
288
+ # Market Selection and Forecast Parameters
289
+ with gr.Row():
290
+ with gr.Column(scale=1):
291
+ gr.Markdown("### Market Selection")
292
+ symbol_dd = gr.Dropdown(
293
+ label="Trading Pair",
294
+ choices=symbols,
295
+ value="BTC-USDT"
296
+ )
297
+ timeframe_dd = gr.Dropdown(
298
+ label="Timeframe",
299
+ choices=list(TIMEFRAME_MAPPING.keys()),
300
+ value="1h"
301
+ )
302
+ with gr.Column(scale=1):
303
+ gr.Markdown("### Forecast Parameters")
304
+ forecast_steps_slider = gr.Slider(
305
+ label="Forecast Steps",
306
+ minimum=1,
307
+ maximum=100,
308
+ value=24,
309
+ step=1
310
+ )
311
+ total_candles_slider = gr.Slider(
312
+ label="Historical Candles",
313
+ minimum=300,
314
+ maximum=3000,
315
+ value=2000,
316
+ step=100
317
+ )
318
+
319
+ # Advanced Settings
320
+ with gr.Row():
321
+ with gr.Column():
322
+ gr.Markdown("### Advanced Settings")
323
+ daily_box = gr.Checkbox(label="Daily Seasonality", value=True)
324
+ weekly_box = gr.Checkbox(label="Weekly Seasonality", value=True)
325
+ yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
326
+ seasonality_mode_dd = gr.Dropdown(
327
+ label="Seasonality Mode",
328
+ choices=["additive", "multiplicative"],
329
+ value="additive"
330
+ )
331
+ changepoint_scale_slider = gr.Slider(
332
+ label="Changepoint Prior Scale",
333
+ minimum=0.01,
334
+ maximum=1.0,
335
+ step=0.01,
336
+ value=0.05
337
+ )
338
+
339
+ # Generate Forecast Button
340
+ forecast_btn = gr.Button("Generate Forecast", variant="primary", size="lg")
341
+
342
+ # Output Plots
343
+ with gr.Row():
344
+ forecast_plot = gr.Plot(label="Price Forecast")
345
+
346
+ with gr.Row():
347
+ tech_plot = gr.Plot(label="Technical Analysis")
348
+ rsi_plot = gr.Plot(label="RSI Indicator")
349
+
350
+ with gr.Row():
351
+ macd_plot = gr.Plot(label="MACD")
352
+
353
+ # Output Data Table
354
+ forecast_df = gr.Dataframe(
355
+ label="Forecast Data",
356
+ headers=["Date", "Forecast", "Lower Bound", "Upper Bound"]
357
+ )
358
+
359
+ # Button click functionality
360
+ forecast_btn.click(
361
+ fn=display_forecast,
362
+ inputs=[
363
+ symbol_dd,
364
+ timeframe_dd,
365
+ forecast_steps_slider,
366
+ total_candles_slider,
367
+ daily_box,
368
+ weekly_box,
369
+ yearly_box,
370
+ seasonality_mode_dd,
371
+ changepoint_scale_slider,
372
+ ],
373
+ outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df]
374
+ )
375
+
376
+ return demo
377
+
378
+ if __name__ == "__main__":
379
+ app = main()
380
+ app.launch()