Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import requests | |
from prophet import Prophet | |
import logging | |
import plotly.graph_objs as go | |
import math | |
logging.basicConfig(level=logging.INFO)

########################################
# OKX endpoints & utility
########################################

OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"

# UI timeframe label -> OKX `bar` parameter (the candles endpoint returns at
# most 300 records per request). OKX spells minute bars in lowercase ("15m")
# but hour/day/week bars in uppercase ("1H", "1D", "1W"), so every label that
# is not a minute bar is upper-cased.
TIMEFRAME_MAPPING = {
    label: (label if label.endswith("m") else label.upper())
    for label in ("1m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "12h", "1d", "1w")
}
######################################## | |
# Functions to fetch data from OKX | |
######################################## | |
def fetch_okx_symbols():
    """
    Return the sorted instrument IDs of all OKX spot tickers.

    Never raises: on an HTTP/parsing failure, a non-"0" OKX response code, or
    an empty result, a one-element list holding an "Error..." message is
    returned instead, so callers inspect the first element to detect failure.
    """
    logging.info("Fetching symbols from OKX Spot tickers...")
    try:
        response = requests.get(OKX_TICKERS_ENDPOINT, timeout=30)
        response.raise_for_status()
        payload = response.json()

        if payload.get("code") != "0":
            logging.error(f"Non-zero code returned: {payload}")
            return ["Error: Could not fetch OKX symbols"]

        spot_ids = sorted(
            ticker["instId"]
            for ticker in payload.get("data", [])
            if ticker.get("instType") == "SPOT"
        )
        if not spot_ids:
            logging.warning("No spot symbols found.")
            return ["Error: No spot symbols found."]

        logging.info(f"Fetched {len(spot_ids)} OKX spot symbols.")
        return spot_ids
    except Exception as exc:
        logging.error(f"Error fetching OKX symbols: {exc}")
        return [f"Error: {str(exc)}"]
# NOTE(review): per the OKX v5 API docs, /market/candles only serves the most
# recent candles (about 1,440 rows) with up to 300 rows per request, while
# /market/history-candles serves older candles with up to 100 rows per
# request. Re-check these limits if OKX changes them.
OKX_HISTORY_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/history-candles"
OKX_CANDLE_PAGE_SIZE = 300
OKX_HISTORY_CANDLE_PAGE_SIZE = 100

# Field order of each row in an OKX candle response.
_OKX_CANDLE_COLUMNS = [
    "ts", "o", "h", "l", "c", "vol",
    "volCcy", "volCcyQuote", "confirm",
]


def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=None, endpoint=None):
    """
    Fetch one page of up to `limit` candles for `symbol` at `timeframe`.

    Args:
        symbol: OKX instrument id (e.g. "BTC-USDT").
        timeframe: OKX `bar` value (e.g. "1H").
        limit: number of rows to request; OKX caps it per endpoint.
        after: optional ms timestamp; only records older than it are returned.
        before: optional ms timestamp; only records newer than it are returned.
        endpoint: candles URL to query; defaults to OKX_CANDLE_ENDPOINT.

    Returns:
        (df, err). df has columns timestamp, open, high, low, close, vol,
        volCcy, volCcyQuote, confirm (all numeric except `timestamp`, which is
        a datetime), newest row first as OKX sends it. err is "" on success
        (df may be empty when OKX has no rows) or a message on failure, in
        which case df is empty.
    """
    if endpoint is None:
        endpoint = OKX_CANDLE_ENDPOINT
    params = {
        "instId": symbol,
        "bar": timeframe,
        "limit": limit,
    }
    if after is not None:
        # fetch records older than 'after'
        params["after"] = str(after)
    if before is not None:
        # fetch records newer than 'before'
        params["before"] = str(before)
    logging.info(f"Fetching chunk: symbol={symbol}, bar={timeframe}, limit={limit}, after={after}, before={before}")
    try:
        resp = requests.get(endpoint, params=params, timeout=30)
        resp.raise_for_status()
        json_data = resp.json()
        if json_data.get("code") != "0":
            msg = f"OKX returned code={json_data.get('code')}, msg={json_data.get('msg')}"
            logging.error(msg)
            return pd.DataFrame(), msg
        items = json_data.get("data", [])
        if not items:
            return pd.DataFrame(), ""
        # Keep only the known fields so an extra trailing field added by OKX
        # cannot break DataFrame construction.
        rows = [row[:len(_OKX_CANDLE_COLUMNS)] for row in items]
        df = pd.DataFrame(rows, columns=_OKX_CANDLE_COLUMNS)
        df.rename(columns={
            "ts": "timestamp",
            "o": "open",
            "h": "high",
            "l": "low",
            "c": "close",
        }, inplace=True)
        # OKX sends ms timestamps as strings; convert them to numbers first,
        # because pandas deprecates to_datetime(..., unit=...) on strings.
        df["timestamp"] = pd.to_datetime(pd.to_numeric(df["timestamp"]), unit="ms")
        numeric_cols = ["open", "high", "low", "close", "vol", "volCcy", "volCcyQuote", "confirm"]
        df[numeric_cols] = df[numeric_cols].astype(float)
        return df, ""
    except Exception as e:
        err_msg = f"Error fetching candles chunk for {symbol}: {e}"
        logging.error(err_msg)
        return pd.DataFrame(), err_msg


def _collect_candle_pages(symbol, timeframe, total):
    """
    Page backwards from the newest candle until `total` rows are collected.

    The recent-candles endpoint is exhausted first; after a short or empty
    page, paging continues on the history endpoint from the same cursor.

    Returns:
        (chunks, err): the list of page DataFrames collected so far (each
        newest-first) and "" on success, or the first error message seen
        (with whatever pages were collected before it).
    """
    sources = [
        (OKX_CANDLE_ENDPOINT, OKX_CANDLE_PAGE_SIZE),
        (OKX_HISTORY_CANDLE_ENDPOINT, OKX_HISTORY_CANDLE_PAGE_SIZE),
    ]
    chunks = []
    collected = 0
    after_ts = None  # ms cursor: request candles strictly older than this
    for endpoint, page_size in sources:
        while collected < total:
            df_chunk, err = fetch_okx_candles_chunk(
                symbol, timeframe, limit=page_size, after=after_ts, endpoint=endpoint
            )
            if err:
                return chunks, err
            if df_chunk.empty:
                break  # this endpoint has nothing older; try the next one
            chunks.append(df_chunk)
            collected += len(df_chunk)
            # Timestamp.value is integer nanoseconds; integer division gives
            # exact milliseconds (no float rounding in the cursor).
            earliest_ms = df_chunk["timestamp"].min().value // 1_000_000
            after_ts = int(earliest_ms - 1)
            if len(df_chunk) < page_size:
                break  # short page: this endpoint has no older data
    return chunks, ""


def fetch_okx_candles(symbol, timeframe="1H", total=2000):
    """
    Fetch the most recent `total` candles by chaining paged requests.

    Args:
        symbol: OKX instrument id.
        timeframe: OKX `bar` value (e.g. "1H").
        total: number of candles wanted; int-like values such as 2000.0 are
            accepted.

    Returns:
        (df, err). df is in chronological order (oldest -> newest),
        de-duplicated on `timestamp`, and holds at most `total` rows. err is
        "" on success. If no rows could be collected, df is empty and err
        explains why ("No data returned." or the request error). If a request
        fails after some rows were collected, those rows are returned and the
        error is only logged as a warning.
    """
    total = int(total)
    logging.info(f"Fetching ~{total} candles for {symbol} @ {timeframe} (in multiple chunks).")
    chunks, err = _collect_candle_pages(symbol, timeframe, total)
    if err:
        if not chunks:
            return pd.DataFrame(), err
        logging.warning(f"Stopping early with partial data for {symbol}: {err}")
    if not chunks:
        logging.info("No data returned overall.")
        return pd.DataFrame(), "No data returned."
    # Pages arrive newest-first; sort to chronological order and keep only the
    # latest `total` rows (full pages usually overshoot the requested count).
    df_all = (
        pd.concat(chunks, ignore_index=True)
        .drop_duplicates(subset="timestamp")
        .sort_values(by="timestamp")
        .tail(total)
        .reset_index(drop=True)
    )
    logging.info(f"Fetched a total of {len(df_all)} rows for {symbol}.")
    return df_all, ""
######################################## | |
# Prophet pipeline | |
######################################## | |
def prepare_data_for_prophet(df):
    """
    Reshape a candle DataFrame into the two-column frame Prophet trains on.

    `timestamp` is renamed to `ds` and `close` to `y`; every other column is
    dropped. An empty input yields an empty frame with exactly those columns.
    """
    if not df.empty:
        renamed = df.rename(columns={"timestamp": "ds", "close": "y"})
        return renamed[["ds", "y"]]
    logging.warning("Empty DataFrame, cannot prepare data for Prophet.")
    return pd.DataFrame(columns=["ds", "y"])
def prophet_forecast(
    df_prophet,
    periods=10,
    freq="h",
    daily_seasonality=False,
    weekly_seasonality=False,
    yearly_seasonality=False,
    seasonality_mode="additive",
    changepoint_prior_scale=0.05,
):
    """
    Fit Prophet on `df_prophet` (columns ds, y) and predict `periods` steps ahead.

    Args:
        periods / freq: forwarded to `make_future_dataframe`.
        daily/weekly/yearly_seasonality: Prophet seasonality toggles.
        seasonality_mode: "additive" or "multiplicative".
        changepoint_prior_scale: trend flexibility (roughly 0.01 to ~10;
            larger values follow the data more closely and risk overfitting).

    Returns:
        (forecast, err): Prophet's full prediction frame (history + future
        rows) and "" on success; (empty DataFrame, message) when the input is
        empty or fitting/prediction raises.
    """
    if df_prophet.empty:
        logging.warning("No data for Prophet.")
        return pd.DataFrame(), "No data to forecast."

    model_options = dict(
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale,
    )
    try:
        model = Prophet(**model_options)
        model.fit(df_prophet)
        horizon = model.make_future_dataframe(periods=periods, freq=freq)
        return model.predict(horizon), ""
    except Exception as exc:
        logging.error(f"Forecast error: {exc}")
        return pd.DataFrame(), f"Forecast error: {exc}"
def prophet_wrapper(
    df_prophet,
    forecast_steps,
    freq,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """
    Run `prophet_forecast` with the user's settings and keep only future rows.

    Args:
        df_prophet: training frame with columns ds, y.
        forecast_steps: number of future periods; int-like values such as
            10.0 are accepted.
        freq: pandas offset alias for the spacing of future timestamps.
        remaining args: forwarded unchanged to `prophet_forecast`.

    Returns:
        (future_df, err): columns ds, yhat, yhat_lower, yhat_upper for
        timestamps after the last observed `ds`, and "" on success; otherwise
        (empty DataFrame, message). Fewer than 10 input rows is an error.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
    full_forecast, err = prophet_forecast(
        df_prophet,
        periods=int(forecast_steps),  # date ranges/slicing need an integer count
        freq=freq,
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale,
    )
    if err:
        return pd.DataFrame(), err
    # Select the future by date rather than by position: slicing at
    # len(df_prophet) is only right when the forecast's history part has
    # exactly one row per input row, which does not hold if df_prophet has
    # duplicate timestamps.
    last_observed = pd.to_datetime(df_prophet["ds"]).max()
    is_future = pd.to_datetime(full_forecast["ds"]) > last_observed
    future_only = full_forecast.loc[is_future, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
    return future_only, ""
######################################## | |
# Plot helper | |
######################################## | |
def create_line_plot(forecast_df):
    """
    Build a Plotly figure of the forecast line and its uncertainty band.

    Expects columns ds, yhat, yhat_lower and yhat_upper. An empty frame yields
    an empty figure.
    """
    if forecast_df.empty:
        return go.Figure()  # nothing to draw

    timestamps = forecast_df["ds"]
    band_style = dict(width=0, color="lightblue")
    traces = [
        go.Scatter(
            x=timestamps,
            y=forecast_df["yhat"],
            mode="lines",
            name="Forecast",
            line=dict(color="blue"),
        ),
        # Lower edge of the band: drawn without any fill.
        go.Scatter(
            x=timestamps,
            y=forecast_df["yhat_lower"],
            fill=None,
            mode="lines",
            line=band_style,
            name="Lower",
        ),
        # Upper edge: "tonexty" shades the area down to the previous trace,
        # i.e. the lower edge, which forms the band.
        go.Scatter(
            x=timestamps,
            y=forecast_df["yhat_upper"],
            fill="tonexty",
            mode="lines",
            line=band_style,
            name="Upper",
        ),
    ]
    figure = go.Figure(data=traces)
    figure.update_layout(
        title="Forecasted Prices",
        xaxis_title="Timestamp",
        yaxis_title="Price",
        hovermode="x",
    )
    return figure
######################################## | |
# Main Gradio logic | |
######################################## | |
# UI timeframe label -> pandas offset alias for Prophet's future timestamps.
# The step must equal the candle spacing, otherwise forecast dates drift away
# from the data (e.g. 15-minute candles forecast at daily steps). "1w" uses
# "7D" because pandas "W" is anchored to Sundays rather than stepping exactly
# one week from the last candle.
PROPHET_FREQ_MAPPING = {
    "1m": "1min",
    "5m": "5min",
    "15m": "15min",
    "30m": "30min",
    "1h": "1h",
    "2h": "2h",
    "4h": "4h",
    "6h": "6h",
    "12h": "12h",
    "1d": "1D",
    "1w": "7D",
}


def predict(
    symbol,
    timeframe,
    forecast_steps,
    total_candles,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """
    Fetch candles for `symbol` and forecast `forecast_steps` future closes.

    Steps:
        1) Fetch `total_candles` historical candles (in multiple requests).
        2) Convert them to Prophet's ds/y format.
        3) Forecast with the user-specified Prophet settings, stepping by the
           candle spacing of `timeframe`.
        4) Return only the future rows.

    Returns:
        (future_df, err): future_df has columns ds, yhat, yhat_lower,
        yhat_upper and err is "" on success; otherwise (empty DataFrame,
        message).
    """
    # Unknown labels fall back to hourly candles ...
    okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
    df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, total=total_candles)
    if err:
        return pd.DataFrame(), err
    df_prophet = prepare_data_for_prophet(df_raw)
    # ... so the forecast frequency falls back to hourly as well.
    freq = PROPHET_FREQ_MAPPING.get(timeframe, "1h")
    future_df, err2 = prophet_wrapper(
        df_prophet,
        forecast_steps,
        freq,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if err2:
        return pd.DataFrame(), err2
    return future_df, ""
def display_forecast(
    symbol,
    timeframe,
    forecast_steps,
    total_candles,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """
    Gradio click handler: run `predict` and return (figure, forecast table).

    Raises:
        gr.Error: when fetching or forecasting fails, so Gradio shows the
            message to the user. Returning the message string as the table
            value does not work: Gradio's Dataframe output treats a plain
            string as a CSV file path. Re-check if upgrading Gradio.
    """
    logging.info(
        f"User requested: symbol={symbol}, timeframe={timeframe}, steps={forecast_steps}, "
        f"total_candles={total_candles}, daily={daily_seasonality}, weekly={weekly_seasonality}, "
        f"yearly={yearly_seasonality}, mode={seasonality_mode}, cps={changepoint_prior_scale}"
    )
    forecast_df, error = predict(
        symbol,
        timeframe,
        forecast_steps,
        total_candles,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if error:
        raise gr.Error(f"Error: {error}")
    fig = create_line_plot(forecast_df)
    return fig, forecast_df
def main():
    """
    Build and return the Gradio Blocks app (the caller launches it).

    OKX spot symbols are fetched once, while the UI is built. If that fetch
    fails, the symbol dropdown offers a single "No symbols available" entry.
    Component creation order below defines the page layout.
    """
    # One source of truth for the candle-count range, shared by the slider and
    # the help text so the two cannot disagree.
    min_candles = 300
    max_candles = 3000

    # Fetch OKX symbols
    symbols = fetch_okx_symbols()
    if not symbols or "Error" in symbols[0]:
        symbols = ["No symbols available"]
    with gr.Blocks() as demo:
        gr.Markdown("# Crypto Price Forecasting with Prophet")
        gr.Markdown(
            "This tool can gather thousands of historical data points from OKX's spot market "
            "and make forecasts using Prophet. You can tweak Prophet's advanced settings or "
            "increase the candle fetch size for potentially more accurate predictions.\n\n"
            f"Simply pick a symbol, timeframe, how many candles ({min_candles}-{max_candles}), "
            "and forecast steps."
        )
        # Input controls
        symbol_dd = gr.Dropdown(
            label="Symbol",
            choices=symbols,
            value=symbols[0] if symbols else None
        )
        timeframe_dd = gr.Dropdown(
            label="Timeframe",
            choices=["1m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "12h", "1d", "1w"],
            value="1h"
        )
        total_candles_slider = gr.Slider(
            label="Total Candles to Fetch",
            minimum=min_candles,
            maximum=max_candles,
            value=2000,
            step=100
        )
        forecast_steps_slider = gr.Slider(
            label="Forecast Steps",
            minimum=1,
            maximum=100,
            value=10,
            step=1,  # whole steps only: the value becomes Prophet's `periods`
        )
        # Prophet advanced settings
        daily_box = gr.Checkbox(label="Daily Seasonality", value=False)
        weekly_box = gr.Checkbox(label="Weekly Seasonality", value=False)
        yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
        seasonality_mode_dd = gr.Dropdown(
            label="Seasonality Mode",
            choices=["additive", "multiplicative"],
            value="additive"
        )
        changepoint_scale_slider = gr.Slider(
            label="Changepoint Prior Scale (0.01 ~ 1.0)",
            minimum=0.01,
            maximum=1.0,
            step=0.01,
            value=0.05
        )
        forecast_btn = gr.Button("Generate Forecast")
        # Outputs
        chart_output = gr.Plot(label="Forecast Chart")
        df_output = gr.Dataframe(
            label="Forecast (Future Only)",
            headers=["ds", "yhat", "yhat_lower", "yhat_upper"]
        )
        # Hook everything up; the input order must match display_forecast's
        # parameter order.
        forecast_btn.click(
            fn=display_forecast,
            inputs=[
                symbol_dd,
                timeframe_dd,
                forecast_steps_slider,
                total_candles_slider,
                daily_box,
                weekly_box,
                yearly_box,
                seasonality_mode_dd,
                changepoint_scale_slider,
            ],
            outputs=[chart_output, df_output]
        )
        # You can choose any text variation you like here
        gr.Markdown(
            "For automated trading tools, consider Gunbot as your next [crypto trading bot](https://www.gunbot.com)."
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    main().launch()