openfree commited on
Commit
b72b034
·
verified ·
1 Parent(s): 51ebd47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -178
app.py CHANGED
@@ -5,6 +5,7 @@ from prophet import Prophet
5
  import logging
6
  import plotly.graph_objs as go
7
  import math
 
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
@@ -15,7 +16,6 @@ logging.basicConfig(level=logging.INFO)
15
  OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
16
  OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"
17
 
18
- # Allowed bar intervals on OKX, maximum 300 records at a time
19
  TIMEFRAME_MAPPING = {
20
  "1m": "1m",
21
  "5m": "5m",
@@ -30,8 +30,59 @@ TIMEFRAME_MAPPING = {
30
  "1w": "1W",
31
  }
32
 
33
- ########################################
34
- # Functions to fetch data from OKX
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ########################################
36
 
37
  def fetch_okx_symbols():
@@ -46,43 +97,37 @@ def fetch_okx_symbols():
46
 
47
  if json_data.get("code") != "0":
48
  logging.error(f"Non-zero code returned: {json_data}")
49
- return ["Error: Could not fetch OKX symbols"]
50
 
51
  data = json_data.get("data", [])
52
  symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
53
  if not symbols:
54
- logging.warning("No spot symbols found.")
55
- return ["Error: No spot symbols found."]
56
 
 
 
 
 
 
57
  logging.info(f"Fetched {len(symbols)} OKX spot symbols.")
58
- return sorted(symbols)
59
 
60
  except Exception as e:
61
  logging.error(f"Error fetching OKX symbols: {e}")
62
- return [f"Error: {str(e)}"]
63
-
64
 
65
  def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=None):
66
- """
67
- Fetch up to `limit` candles (max 300) for the given symbol/timeframe.
68
- Optionally use `after` or `before` to page through older or newer data.
69
-
70
- OKX returns newest data first. The result here is also newest first.
71
- We'll reorder or combine them later as needed.
72
- """
73
  params = {
74
  "instId": symbol,
75
  "bar": timeframe,
76
  "limit": limit
77
  }
78
  if after is not None:
79
- # fetch records older than 'after'
80
  params["after"] = str(after)
81
  if before is not None:
82
- # fetch records newer than 'before'
83
  params["before"] = str(before)
84
 
85
- logging.info(f"Fetching chunk: symbol={symbol}, bar={timeframe}, limit={limit}, after={after}, before={before}")
86
  try:
87
  resp = requests.get(OKX_CANDLE_ENDPOINT, params=params, timeout=30)
88
  resp.raise_for_status()
@@ -97,11 +142,7 @@ def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=Non
97
  if not items:
98
  return pd.DataFrame(), ""
99
 
100
- # items are newest first. We'll parse them in that order, then we can reverse later.
101
- columns = [
102
- "ts", "o", "h", "l", "c", "vol",
103
- "volCcy", "volCcyQuote", "confirm"
104
- ]
105
  df = pd.DataFrame(items, columns=columns)
106
  df.rename(columns={
107
  "ts": "timestamp",
@@ -121,20 +162,16 @@ def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=Non
121
  return pd.DataFrame(), err_msg
122
 
123
 
 
124
  def fetch_okx_candles(symbol, timeframe="1H", total=2000):
125
  """
126
- Fetch ~`total` candles by chaining multiple requests of up to 300 each.
127
- We'll get the newest data first, then request older data in loops,
128
- because 'after' param returns records older than the provided ts.
129
-
130
- Returns df in chronological order (oldest -> newest).
131
  """
132
- logging.info(f"Fetching ~{total} candles for {symbol} @ {timeframe} (in multiple chunks).")
133
 
134
- # We'll do enough calls to get at least `total` data points, or break if no more data.
135
  calls_needed = math.ceil(total / 300.0)
136
  all_data = []
137
- after_ts = None # We'll track the earliest timestamp we see, then pass "after" to go older
138
 
139
  for _ in range(calls_needed):
140
  df_chunk, err = fetch_okx_candles_chunk(
@@ -143,51 +180,38 @@ def fetch_okx_candles(symbol, timeframe="1H", total=2000):
143
  if err:
144
  return pd.DataFrame(), err
145
  if df_chunk.empty:
146
- # No more data
147
  break
148
 
149
- # df_chunk is newest first, so the last row is the earliest in that chunk.
150
  earliest_ts = df_chunk["timestamp"].iloc[-1]
151
- # We'll keep chaining to older data by passing after = earliest_ts-1 (in ms).
152
- # But we need that as a Unix milliseconds integer.
153
  after_ts = int(earliest_ts.timestamp() * 1000 - 1)
154
-
155
- # Add this chunk to the big list
156
  all_data.append(df_chunk)
157
 
158
  if len(df_chunk) < 300:
159
- # We didn't get a full chunk, means no more older data available
160
  break
161
 
162
- # Concatenate everything
163
  if not all_data:
164
- logging.info("No data returned overall.")
165
  return pd.DataFrame(), "No data returned."
166
 
167
  df_all = pd.concat(all_data, ignore_index=True)
168
- # Each chunk is newest first, so the entire df is a bunch of blocks newest->oldest blocks.
169
- # Let's invert the final large df to chronological
170
  df_all.sort_values(by="timestamp", inplace=True)
171
  df_all.reset_index(drop=True, inplace=True)
172
- logging.info(f"Fetched a total of {len(df_all)} rows for {symbol}.")
 
 
 
 
173
  return df_all, ""
174
 
175
-
176
  ########################################
177
- # Prophet pipeline
178
  ########################################
179
 
180
  def prepare_data_for_prophet(df):
181
- """
182
- Convert DataFrame to Prophet-compatible format: columns ds, y.
183
- """
184
  if df.empty:
185
- logging.warning("Empty DataFrame, cannot prepare data for Prophet.")
186
  return pd.DataFrame(columns=["ds", "y"])
187
  df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
188
  return df_prophet[["ds", "y"]]
189
 
190
-
191
  def prophet_forecast(
192
  df_prophet,
193
  periods=10,
@@ -198,15 +222,8 @@ def prophet_forecast(
198
  seasonality_mode="additive",
199
  changepoint_prior_scale=0.05,
200
  ):
201
- """
202
- Train a Prophet model with various exposed settings:
203
- - daily/weekly/yearly seasonality toggles
204
- - seasonality_mode ("additive" or "multiplicative")
205
- - changepoint_prior_scale (0.01 to ~10, controls overfitting)
206
- """
207
  if df_prophet.empty:
208
- logging.warning("No data for Prophet.")
209
- return pd.DataFrame(), "No data to forecast."
210
 
211
  try:
212
  model = Prophet(
@@ -225,6 +242,8 @@ def prophet_forecast(
225
  return pd.DataFrame(), f"Forecast error: {e}"
226
 
227
 
 
 
228
  def prophet_wrapper(
229
  df_prophet,
230
  forecast_steps,
@@ -235,9 +254,6 @@ def prophet_wrapper(
235
  seasonality_mode,
236
  changepoint_prior_scale,
237
  ):
238
- """
239
- Run the forecast with user-chosen settings, then keep future (new) rows only.
240
- """
241
  if len(df_prophet) < 10:
242
  return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
243
 
@@ -254,21 +270,16 @@ def prophet_wrapper(
254
  if err:
255
  return pd.DataFrame(), err
256
 
257
- # Future portion only: the new rows after the original data
258
  future_only = full_forecast.loc[len(df_prophet):, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
259
  return future_only, ""
260
 
261
-
262
  ########################################
263
- # Plot helper
264
  ########################################
265
 
266
- def create_line_plot(forecast_df):
267
- """
268
- Make a Plotly line chart from forecast.
269
- """
270
  if forecast_df.empty:
271
- return go.Figure() # empty figure
272
 
273
  fig = go.Figure()
274
  fig.add_trace(go.Scatter(
@@ -276,40 +287,45 @@ def create_line_plot(forecast_df):
276
  y=forecast_df["yhat"],
277
  mode="lines",
278
  name="Forecast",
279
- line=dict(color="blue")
280
  ))
281
 
282
- # Lower bound
283
  fig.add_trace(go.Scatter(
284
  x=forecast_df["ds"],
285
  y=forecast_df["yhat_lower"],
286
  fill=None,
287
  mode="lines",
288
- line=dict(width=0, color="lightblue"),
289
- name="Lower"
 
290
  ))
291
 
292
- # Upper bound
293
  fig.add_trace(go.Scatter(
294
  x=forecast_df["ds"],
295
  y=forecast_df["yhat_upper"],
296
  fill="tonexty",
297
  mode="lines",
298
- line=dict(width=0, color="lightblue"),
299
- name="Upper"
300
  ))
301
 
302
  fig.update_layout(
303
- title="Forecasted Prices",
304
- xaxis_title="Timestamp",
305
  yaxis_title="Price",
306
- hovermode="x"
 
 
 
 
 
 
 
307
  )
308
  return fig
309
 
310
-
311
  ########################################
312
- # Main Gradio logic
313
  ########################################
314
 
315
  def predict(
@@ -323,23 +339,12 @@ def predict(
323
  seasonality_mode,
324
  changepoint_prior_scale,
325
  ):
326
- """
327
- 1) Fetch `total_candles` historical data (in multiple parts if needed)
328
- 2) Convert to Prophet style
329
- 3) Run forecast with user-specified Prophet settings
330
- 4) Return future portion
331
- """
332
- # Convert timeframe to OKX style
333
  okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
334
-
335
- # This fetch can yield thousands of candles
336
  df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, total=total_candles)
337
  if err:
338
- return pd.DataFrame(), err
339
 
340
  df_prophet = prepare_data_for_prophet(df_raw)
341
-
342
- # Decide Prophet frequency
343
  freq = "h" if "h" in timeframe.lower() else "d"
344
 
345
  future_df, err2 = prophet_wrapper(
@@ -353,9 +358,10 @@ def predict(
353
  changepoint_prior_scale,
354
  )
355
  if err2:
356
- return pd.DataFrame(), err2
 
 
357
 
358
- return future_df, ""
359
 
360
 
361
  def display_forecast(
@@ -369,12 +375,9 @@ def display_forecast(
369
  seasonality_mode,
370
  changepoint_prior_scale,
371
  ):
372
- logging.info(
373
- f"User requested: symbol={symbol}, timeframe={timeframe}, steps={forecast_steps}, "
374
- f"total_candles={total_candles}, daily={daily_seasonality}, weekly={weekly_seasonality}, "
375
- f"yearly={yearly_seasonality}, mode={seasonality_mode}, cps={changepoint_prior_scale}"
376
- )
377
- forecast_df, error = predict(
378
  symbol,
379
  timeframe,
380
  forecast_steps,
@@ -385,80 +388,95 @@ def display_forecast(
385
  seasonality_mode,
386
  changepoint_prior_scale,
387
  )
 
388
  if error:
389
- return None, f"Error: {error}"
390
-
391
- fig = create_line_plot(forecast_df)
392
- return fig, forecast_df
393
 
 
 
 
 
394
 
395
  def main():
396
- # Fetch OKX symbols
397
  symbols = fetch_okx_symbols()
398
- if not symbols or "Error" in symbols[0]:
399
- symbols = ["No symbols available"]
400
-
401
- with gr.Blocks() as demo:
402
- gr.Markdown("# Crypto Price Forecasting with Prophet")
403
- gr.Markdown(
404
- "This tool can gather thousands of historical data points from OKX's spot market "
405
- "and make forecasts using Prophet. You can tweak Prophet's advanced settings or "
406
- "increase the candle fetch size for potentially more accurate predictions.\n\n"
407
- "Simply pick a symbol, timeframe, how many candles (max ~2000), and forecast steps."
408
- )
409
-
410
- # Input controls
411
- symbol_dd = gr.Dropdown(
412
- label="Symbol",
413
- choices=symbols,
414
- value=symbols[0] if symbols else None
415
- )
416
- timeframe_dd = gr.Dropdown(
417
- label="Timeframe",
418
- choices=["1m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "12h", "1d", "1w"],
419
- value="1h"
420
- )
421
- total_candles_slider = gr.Slider(
422
- label="Total Candles to Fetch",
423
- minimum=300,
424
- maximum=3000,
425
- value=2000,
426
- step=100
427
- )
428
- forecast_steps_slider = gr.Slider(
429
- label="Forecast Steps",
430
- minimum=1,
431
- maximum=100,
432
- value=10
433
- )
434
-
435
- # Prophet advanced settings
436
- daily_box = gr.Checkbox(label="Daily Seasonality", value=False)
437
- weekly_box = gr.Checkbox(label="Weekly Seasonality", value=False)
438
- yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
439
- seasonality_mode_dd = gr.Dropdown(
440
- label="Seasonality Mode",
441
- choices=["additive", "multiplicative"],
442
- value="additive"
443
- )
444
- changepoint_scale_slider = gr.Slider(
445
- label="Changepoint Prior Scale (0.01 ~ 1.0)",
446
- minimum=0.01,
447
- maximum=1.0,
448
- step=0.01,
449
- value=0.05
450
- )
451
-
452
- forecast_btn = gr.Button("Generate Forecast")
453
-
454
- # Outputs
455
- chart_output = gr.Plot(label="Forecast Chart")
456
- df_output = gr.Dataframe(
457
- label="Forecast (Future Only)",
458
- headers=["ds", "yhat", "yhat_lower", "yhat_upper"]
459
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- # Hook everything up
462
  forecast_btn.click(
463
  fn=display_forecast,
464
  inputs=[
@@ -472,17 +490,11 @@ def main():
472
  seasonality_mode_dd,
473
  changepoint_scale_slider,
474
  ],
475
- outputs=[chart_output, df_output]
476
- )
477
-
478
- # You can choose any text variation you like here
479
- gr.Markdown(
480
- "For automated trading tools, consider Gunbot as your next [crypto trading bot](https://www.gunbot.com)."
481
  )
482
 
483
  return demo
484
 
485
-
486
  if __name__ == "__main__":
487
  app = main()
488
- app.launch()
 
5
  import logging
6
  import plotly.graph_objs as go
7
  import math
8
+ import numpy as np
9
 
10
  logging.basicConfig(level=logging.INFO)
11
 
 
16
  OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
17
  OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"
18
 
 
19
  TIMEFRAME_MAPPING = {
20
  "1m": "1m",
21
  "5m": "5m",
 
30
  "1w": "1W",
31
  }
32
 
33
+ def calculate_technical_indicators(df):
34
+ # Calculate RSI
35
+ delta = df['close'].diff()
36
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
37
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
38
+ rs = gain / loss
39
+ df['RSI'] = 100 - (100 / (1 + rs))
40
+
41
+ # Calculate MACD
42
+ exp1 = df['close'].ewm(span=12, adjust=False).mean()
43
+ exp2 = df['close'].ewm(span=26, adjust=False).mean()
44
+ df['MACD'] = exp1 - exp2
45
+ df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()
46
+
47
+ # Calculate Bollinger Bands
48
+ df['MA20'] = df['close'].rolling(window=20).mean()
49
+ df['BB_upper'] = df['MA20'] + 2 * df['close'].rolling(window=20).std()
50
+ df['BB_lower'] = df['MA20'] - 2 * df['close'].rolling(window=20).std()
51
+
52
+ return df
53
+
54
+ def create_technical_charts(df):
55
+ # Price and Bollinger Bands
56
+ fig1 = go.Figure()
57
+ fig1.add_trace(go.Candlestick(
58
+ x=df['timestamp'],
59
+ open=df['open'],
60
+ high=df['high'],
61
+ low=df['low'],
62
+ close=df['close'],
63
+ name='Price'
64
+ ))
65
+ fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_upper'], name='Upper BB', line=dict(color='gray', dash='dash')))
66
+ fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_lower'], name='Lower BB', line=dict(color='gray', dash='dash')))
67
+ fig1.update_layout(title='Price and Bollinger Bands', xaxis_title='Date', yaxis_title='Price')
68
+
69
+ # RSI
70
+ fig2 = go.Figure()
71
+ fig2.add_trace(go.Scatter(x=df['timestamp'], y=df['RSI'], name='RSI'))
72
+ fig2.add_hline(y=70, line_dash="dash", line_color="red")
73
+ fig2.add_hline(y=30, line_dash="dash", line_color="green")
74
+ fig2.update_layout(title='RSI Indicator', xaxis_title='Date', yaxis_title='RSI')
75
+
76
+ # MACD
77
+ fig3 = go.Figure()
78
+ fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['MACD'], name='MACD'))
79
+ fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['Signal_Line'], name='Signal Line'))
80
+ fig3.update_layout(title='MACD', xaxis_title='Date', yaxis_title='Value')
81
+
82
+ return fig1, fig2, fig3
83
+
84
+ ########################################
85
+ # OKX Data Fetching Functions
86
  ########################################
87
 
88
  def fetch_okx_symbols():
 
97
 
98
  if json_data.get("code") != "0":
99
  logging.error(f"Non-zero code returned: {json_data}")
100
+ return ["BTC-USDT"] # Default fallback
101
 
102
  data = json_data.get("data", [])
103
  symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
104
  if not symbols:
105
+ return ["BTC-USDT"]
 
106
 
107
+ # Ensure BTC-USDT is first in the list
108
+ if "BTC-USDT" in symbols:
109
+ symbols.remove("BTC-USDT")
110
+ symbols.insert(0, "BTC-USDT")
111
+
112
  logging.info(f"Fetched {len(symbols)} OKX spot symbols.")
113
+ return symbols
114
 
115
  except Exception as e:
116
  logging.error(f"Error fetching OKX symbols: {e}")
117
+ return ["BTC-USDT"]
 
118
 
119
  def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=None):
 
 
 
 
 
 
 
120
  params = {
121
  "instId": symbol,
122
  "bar": timeframe,
123
  "limit": limit
124
  }
125
  if after is not None:
 
126
  params["after"] = str(after)
127
  if before is not None:
 
128
  params["before"] = str(before)
129
 
130
+ logging.info(f"Fetching chunk: symbol={symbol}, bar={timeframe}, limit={limit}")
131
  try:
132
  resp = requests.get(OKX_CANDLE_ENDPOINT, params=params, timeout=30)
133
  resp.raise_for_status()
 
142
  if not items:
143
  return pd.DataFrame(), ""
144
 
145
+ columns = ["ts", "o", "h", "l", "c", "vol", "volCcy", "volCcyQuote", "confirm"]
 
 
 
 
146
  df = pd.DataFrame(items, columns=columns)
147
  df.rename(columns={
148
  "ts": "timestamp",
 
162
  return pd.DataFrame(), err_msg
163
 
164
 
165
+
166
  def fetch_okx_candles(symbol, timeframe="1H", total=2000):
167
  """
168
+ Fetch historical candle data
 
 
 
 
169
  """
170
+ logging.info(f"Fetching ~{total} candles for {symbol} @ {timeframe}")
171
 
 
172
  calls_needed = math.ceil(total / 300.0)
173
  all_data = []
174
+ after_ts = None
175
 
176
  for _ in range(calls_needed):
177
  df_chunk, err = fetch_okx_candles_chunk(
 
180
  if err:
181
  return pd.DataFrame(), err
182
  if df_chunk.empty:
 
183
  break
184
 
 
185
  earliest_ts = df_chunk["timestamp"].iloc[-1]
 
 
186
  after_ts = int(earliest_ts.timestamp() * 1000 - 1)
 
 
187
  all_data.append(df_chunk)
188
 
189
  if len(df_chunk) < 300:
 
190
  break
191
 
 
192
  if not all_data:
 
193
  return pd.DataFrame(), "No data returned."
194
 
195
  df_all = pd.concat(all_data, ignore_index=True)
 
 
196
  df_all.sort_values(by="timestamp", inplace=True)
197
  df_all.reset_index(drop=True, inplace=True)
198
+
199
+ # Calculate technical indicators
200
+ df_all = calculate_technical_indicators(df_all)
201
+
202
+ logging.info(f"Fetched {len(df_all)} rows for {symbol}.")
203
  return df_all, ""
204
 
 
205
  ########################################
206
+ # Prophet Pipeline
207
  ########################################
208
 
209
  def prepare_data_for_prophet(df):
 
 
 
210
  if df.empty:
 
211
  return pd.DataFrame(columns=["ds", "y"])
212
  df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
213
  return df_prophet[["ds", "y"]]
214
 
 
215
  def prophet_forecast(
216
  df_prophet,
217
  periods=10,
 
222
  seasonality_mode="additive",
223
  changepoint_prior_scale=0.05,
224
  ):
 
 
 
 
 
 
225
  if df_prophet.empty:
226
+ return pd.DataFrame(), "No data for Prophet."
 
227
 
228
  try:
229
  model = Prophet(
 
242
  return pd.DataFrame(), f"Forecast error: {e}"
243
 
244
 
245
+
246
+
247
  def prophet_wrapper(
248
  df_prophet,
249
  forecast_steps,
 
254
  seasonality_mode,
255
  changepoint_prior_scale,
256
  ):
 
 
 
257
  if len(df_prophet) < 10:
258
  return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
259
 
 
270
  if err:
271
  return pd.DataFrame(), err
272
 
 
273
  future_only = full_forecast.loc[len(df_prophet):, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
274
  return future_only, ""
275
 
 
276
  ########################################
277
+ # Plotting Functions
278
  ########################################
279
 
280
+ def create_forecast_plot(forecast_df):
 
 
 
281
  if forecast_df.empty:
282
+ return go.Figure()
283
 
284
  fig = go.Figure()
285
  fig.add_trace(go.Scatter(
 
287
  y=forecast_df["yhat"],
288
  mode="lines",
289
  name="Forecast",
290
+ line=dict(color="blue", width=2)
291
  ))
292
 
 
293
  fig.add_trace(go.Scatter(
294
  x=forecast_df["ds"],
295
  y=forecast_df["yhat_lower"],
296
  fill=None,
297
  mode="lines",
298
+ line=dict(width=0),
299
+ showlegend=True,
300
+ name="Lower Bound"
301
  ))
302
 
 
303
  fig.add_trace(go.Scatter(
304
  x=forecast_df["ds"],
305
  y=forecast_df["yhat_upper"],
306
  fill="tonexty",
307
  mode="lines",
308
+ line=dict(width=0),
309
+ name="Upper Bound"
310
  ))
311
 
312
  fig.update_layout(
313
+ title="Price Forecast",
314
+ xaxis_title="Time",
315
  yaxis_title="Price",
316
+ hovermode="x unified",
317
+ template="plotly_white",
318
+ legend=dict(
319
+ yanchor="top",
320
+ y=0.99,
321
+ xanchor="left",
322
+ x=0.01
323
+ )
324
  )
325
  return fig
326
 
 
327
  ########################################
328
+ # Main Prediction Function
329
  ########################################
330
 
331
  def predict(
 
339
  seasonality_mode,
340
  changepoint_prior_scale,
341
  ):
 
 
 
 
 
 
 
342
  okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
 
 
343
  df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, total=total_candles)
344
  if err:
345
+ return pd.DataFrame(), pd.DataFrame(), err
346
 
347
  df_prophet = prepare_data_for_prophet(df_raw)
 
 
348
  freq = "h" if "h" in timeframe.lower() else "d"
349
 
350
  future_df, err2 = prophet_wrapper(
 
358
  changepoint_prior_scale,
359
  )
360
  if err2:
361
+ return pd.DataFrame(), pd.DataFrame(), err2
362
+
363
+ return df_raw, future_df, ""
364
 
 
365
 
366
 
367
  def display_forecast(
 
375
  seasonality_mode,
376
  changepoint_prior_scale,
377
  ):
378
+ logging.info(f"Processing forecast request for {symbol}")
379
+
380
+ df_raw, forecast_df, error = predict(
 
 
 
381
  symbol,
382
  timeframe,
383
  forecast_steps,
 
388
  seasonality_mode,
389
  changepoint_prior_scale,
390
  )
391
+
392
  if error:
393
+ return None, None, None, None, f"Error: {error}"
 
 
 
394
 
395
+ forecast_plot = create_forecast_plot(forecast_df)
396
+ tech_plot, rsi_plot, macd_plot = create_technical_charts(df_raw)
397
+
398
+ return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df
399
 
400
  def main():
 
401
  symbols = fetch_okx_symbols()
402
+
403
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
404
+ with gr.Row():
405
+ gr.Markdown("# Cryptocurrency Price Forecasting System")
406
+
407
+ with gr.Row():
408
+ with gr.Column(scale=1):
409
+ with gr.Box():
410
+ gr.Markdown("### Market Selection")
411
+ symbol_dd = gr.Dropdown(
412
+ label="Trading Pair",
413
+ choices=symbols,
414
+ value="BTC-USDT"
415
+ )
416
+ timeframe_dd = gr.Dropdown(
417
+ label="Timeframe",
418
+ choices=list(TIMEFRAME_MAPPING.keys()),
419
+ value="1h"
420
+ )
421
+
422
+ with gr.Column(scale=1):
423
+ with gr.Box():
424
+ gr.Markdown("### Forecast Parameters")
425
+ forecast_steps_slider = gr.Slider(
426
+ label="Forecast Steps",
427
+ minimum=1,
428
+ maximum=100,
429
+ value=24,
430
+ step=1
431
+ )
432
+ total_candles_slider = gr.Slider(
433
+ label="Historical Candles",
434
+ minimum=300,
435
+ maximum=3000,
436
+ value=2000,
437
+ step=100
438
+ )
439
+
440
+ with gr.Row():
441
+ with gr.Column():
442
+ with gr.Box():
443
+ gr.Markdown("### Advanced Settings")
444
+ with gr.Row():
445
+ daily_box = gr.Checkbox(label="Daily Seasonality", value=True)
446
+ weekly_box = gr.Checkbox(label="Weekly Seasonality", value=True)
447
+ yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
448
+ seasonality_mode_dd = gr.Dropdown(
449
+ label="Seasonality Mode",
450
+ choices=["additive", "multiplicative"],
451
+ value="additive"
452
+ )
453
+ changepoint_scale_slider = gr.Slider(
454
+ label="Changepoint Prior Scale",
455
+ minimum=0.01,
456
+ maximum=1.0,
457
+ step=0.01,
458
+ value=0.05
459
+ )
460
+
461
+ with gr.Row():
462
+ forecast_btn = gr.Button("Generate Forecast", variant="primary", size="lg")
463
+
464
+ with gr.Row():
465
+ forecast_plot = gr.Plot(label="Price Forecast")
466
+
467
+ with gr.Row():
468
+ tech_plot = gr.Plot(label="Technical Analysis")
469
+ rsi_plot = gr.Plot(label="RSI Indicator")
470
+
471
+ with gr.Row():
472
+ macd_plot = gr.Plot(label="MACD")
473
+
474
+ with gr.Row():
475
+ forecast_df = gr.Dataframe(
476
+ label="Forecast Data",
477
+ headers=["Date", "Forecast", "Lower Bound", "Upper Bound"]
478
+ )
479
 
 
480
  forecast_btn.click(
481
  fn=display_forecast,
482
  inputs=[
 
490
  seasonality_mode_dd,
491
  changepoint_scale_slider,
492
  ],
493
+ outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df]
 
 
 
 
 
494
  )
495
 
496
  return demo
497
 
 
498
  if __name__ == "__main__":
499
  app = main()
500
+ app.launch()