Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import requests | |
from prophet import Prophet | |
import logging | |
logging.basicConfig(level=logging.INFO)

########################################
# OKX REST endpoints and timeframe table
########################################

# Spot tickers; each entry in the response's `data` list carries an `instId`
# such as "BTC-USDT".
OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"

# Candlesticks, queried with `instId`, `bar` and `limit` parameters, e.g.
# https://www.okx.com/api/v5/market/candles?instId=BTC-USDT&bar=1H&limit=100
OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"

# UI timeframe label -> OKX `bar` value. Minute bars keep a lower-case "m";
# hour, day and week bars use an upper-case suffix. Extend this table to
# expose more OKX intervals in the UI.
TIMEFRAME_MAPPING = {
    # minute bars
    "1m": "1m", "5m": "5m", "15m": "15m", "30m": "30m",
    # hour bars
    "1h": "1H", "2h": "2H", "4h": "4H", "6h": "6H", "12h": "12H",
    # day and week bars
    "1d": "1D", "1w": "1W",
}
def fetch_okx_symbols():
    """
    Return the sorted list of OKX spot instrument ids (e.g. "ETH-USDT").

    Failures are reported in-band rather than raised: the result is then a
    one-element list whose only entry starts with "Error".
    """
    logging.info("Fetching symbols from OKX Spot tickers...")
    try:
        response = requests.get(OKX_TICKERS_ENDPOINT, timeout=30)
        response.raise_for_status()
        payload = response.json()

        # OKX signals application-level failures via a string `code` field.
        if payload.get("code") != "0":
            logging.error(f"Non-zero code returned: {payload}")
            return ["Error: Could not fetch OKX symbols"]

        spot_ids = []
        for entry in payload.get("data", []):
            if entry.get("instType") == "SPOT":
                spot_ids.append(entry["instId"])

        if len(spot_ids) == 0:
            logging.warning("No spot symbols found.")
            return ["Error: No spot symbols found."]

        logging.info(f"Fetched {len(spot_ids)} OKX spot symbols.")
        return sorted(spot_ids)
    except Exception as e:
        logging.error(f"Error fetching OKX symbols: {e}")
        return [f"Error: {str(e)}"]
# Names for the leading fields of each OKX candle row, in response order.
_OKX_CANDLE_COLUMNS = ["timestamp", "open", "high", "low", "close", "volume", "quoteVolume"]


def _candles_to_dataframe(items):
    """Convert OKX candle rows (newest first, string fields) into a chronological float DataFrame."""
    # OKX lists the newest candle first; reverse for oldest -> newest order.
    # Newer OKX v5 responses append extra trailing fields (e.g. volCcyQuote,
    # confirm) after the seven named above, so keep only the leading seven;
    # a plain 7-field row passes through unchanged.
    width = len(_OKX_CANDLE_COLUMNS)
    rows = [row[:width] for row in reversed(items)]
    df = pd.DataFrame(rows, columns=_OKX_CANDLE_COLUMNS)
    # Timestamps arrive as millisecond strings; make them numeric before
    # applying unit="ms" (pandas deprecates `unit=` combined with strings).
    df["timestamp"] = pd.to_datetime(pd.to_numeric(df["timestamp"]), unit="ms")
    value_columns = _OKX_CANDLE_COLUMNS[1:]
    df[value_columns] = df[value_columns].astype(float)
    return df


def fetch_okx_candles(symbol, timeframe="1H", limit=100):
    """
    Fetch historical candle data for one OKX instrument.

    Args:
        symbol: OKX instId, e.g. "BTC-USDT".
        timeframe: OKX `bar` value such as "1H", "4H" or "1D".
        limit: number of candles to request.

    Returns:
        (DataFrame, "") on success. The DataFrame has columns timestamp, open,
        high, low, close, volume and quoteVolume, ordered oldest -> newest.
        (empty DataFrame, message) on any failure; errors are returned, not raised.
    """
    logging.info(f"Fetching {limit} candles for {symbol} @ {timeframe} from OKX...")
    params = {"instId": symbol, "bar": timeframe, "limit": limit}
    try:
        resp = requests.get(OKX_CANDLE_ENDPOINT, params=params, timeout=30)
        resp.raise_for_status()
        json_data = resp.json()
        if json_data.get("code") != "0":
            msg = f"OKX returned code={json_data.get('code')}, msg={json_data.get('msg')}"
            logging.error(msg)
            return pd.DataFrame(), msg
        items = json_data.get("data", [])
        if not items:
            warning_msg = f"No candle data returned for {symbol}."
            logging.warning(warning_msg)
            return pd.DataFrame(), warning_msg
        df = _candles_to_dataframe(items)
        logging.info(f"Fetched {len(df)} rows for {symbol}.")
        return df, ""
    except Exception as e:
        err_msg = f"Error fetching candles for {symbol}: {e}"
        logging.error(err_msg)
        return pd.DataFrame(), err_msg
######################################## | |
# Prophet pipeline | |
######################################## | |
def prepare_data_for_prophet(df):
    """
    Reshape a candle DataFrame into the two-column frame Prophet trains on.

    `timestamp` becomes `ds` and `close` becomes `y`; all other columns are
    dropped. An empty input yields an empty frame with `ds`/`y` columns.
    """
    if not df.empty:
        renamed = df.rename(columns={"timestamp": "ds", "close": "y"})
        return renamed[["ds", "y"]]
    logging.warning("Empty DataFrame, cannot prepare data for Prophet.")
    return pd.DataFrame(columns=["ds", "y"])
def prophet_forecast(df_prophet, periods=10, freq="h"):
    """
    Fit a fresh Prophet model on `df_prophet` and predict `periods` steps ahead.

    Args:
        df_prophet: DataFrame with Prophet's `ds` (datetime) and `y` columns.
        periods: number of future steps; coerced to int because Prophet's
            make_future_dataframe expects an integer count (UI sliders may
            deliver floats).
        freq: pandas frequency string for one step ("h" = hourly; the
            upper-case "H" alias is deprecated in recent pandas).

    Returns:
        (forecast DataFrame from model.predict, "") on success, or
        (empty DataFrame, error message) when the input is empty or Prophet
        fails. Errors are returned, not raised.
    """
    if df_prophet.empty:
        logging.warning("Prophet input is empty, no forecast can be generated.")
        return pd.DataFrame(), "No data to forecast."
    try:
        periods = int(periods)
        model = Prophet()
        model.fit(df_prophet)
        future = model.make_future_dataframe(periods=periods, freq=freq)
        forecast = model.predict(future)
        return forecast, ""
    except Exception as e:
        logging.error(f"Forecast error: {e}")
        return pd.DataFrame(), f"Forecast error: {e}"
def prophet_wrapper(df_prophet, forecast_steps, freq):
    """
    Run prophet_forecast and keep only the rows dated after the training data.

    Args:
        df_prophet: DataFrame with `ds`/`y` columns; needs at least 10 rows.
        forecast_steps: number of future steps to forecast.
        freq: pandas frequency string for one step.

    Returns:
        (DataFrame with ds, yhat, yhat_lower, yhat_upper, "") on success, or
        (empty DataFrame, error message) on failure.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
    full_forecast, err = prophet_forecast(df_prophet, forecast_steps, freq)
    if err:
        return pd.DataFrame(), err
    # `.iloc` accepts only positions, so columns are selected by label via
    # `.loc`. Filtering on timestamps instead of a row offset avoids assuming
    # the forecast holds exactly one row per input row (e.g. duplicate `ds`).
    last_history = df_prophet["ds"].max()
    future_mask = full_forecast["ds"] > last_history
    future_only = full_forecast.loc[
        future_mask, ["ds", "yhat", "yhat_lower", "yhat_upper"]
    ].reset_index(drop=True)
    return future_only, ""
######################################## | |
# Main Gradio logic | |
######################################## | |
def _timeframe_to_pandas_freq(timeframe):
    """Return the pandas frequency string spanning one candle of a UI timeframe label."""
    freq_by_timeframe = {
        "1m": "1min",
        "5m": "5min",
        "15m": "15min",
        "30m": "30min",
        "1h": "1h",
        "2h": "2h",
        "4h": "4h",
        "6h": "6h",
        "12h": "12h",
        "1d": "1D",
        # A fixed 7-day step, rather than an anchored weekly alias such as
        # "W", keeps future rows on the weekday of the last candle.
        "1w": "7D",
    }
    # Unknown labels fall back to hourly, matching the "1H" OKX bar they get.
    return freq_by_timeframe.get(timeframe, "1h")


def predict(symbol, timeframe, forecast_steps):
    """
    Fetch candles for `symbol` and forecast the next `forecast_steps` candles.

    Args:
        symbol: OKX instId, e.g. "BTC-USDT".
        timeframe: UI timeframe label ("1m" ... "1w"); unknown labels are
            treated as hourly.
        forecast_steps: number of future candles to predict.

    Returns:
        (future-only forecast DataFrame, "") on success, or
        (empty DataFrame, error message) on failure.
    """
    okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
    # OKX's API reference lists 300 as the maximum `limit` for /market/candles.
    df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, limit=300)
    if err:
        return pd.DataFrame(), err
    df_prophet = prepare_data_for_prophet(df_raw)
    # Each forecast step must equal one candle of the chosen timeframe. A
    # substring test such as `"h" in timeframe` would give minute/weekly bars
    # daily steps and multi-hour bars 1-hour steps.
    freq = _timeframe_to_pandas_freq(timeframe)
    future_df, err2 = prophet_wrapper(df_prophet, forecast_steps, freq)
    if err2:
        return pd.DataFrame(), err2
    return future_df, ""
def display_forecast(symbol, timeframe, forecast_steps):
    """
    Gradio click handler: return the forecast table, or the error as a table.

    The handler feeds a gr.Dataframe output, so every return value is a
    DataFrame. On failure it is a one-column "error" frame whose single cell
    reads "Error: <message>".
    """
    logging.info(f"User requested: symbol={symbol}, timeframe={timeframe}, steps={forecast_steps}")
    forecast_df, error = predict(symbol, timeframe, forecast_steps)
    if error:
        # A bare string sent to a gr.Dataframe output is not shown as text
        # (Gradio interprets a str value as a CSV file path), so wrap it.
        return pd.DataFrame({"error": [f"Error: {error}"]})
    return forecast_df
def main():
    """
    Build (but do not launch) the Gradio Blocks UI for OKX price forecasting.

    OKX spot symbols are fetched once, while the UI is built. If that fetch
    reports an error, the symbol dropdown shows a single
    "No symbols available" placeholder.

    Returns:
        The gr.Blocks app.
    """
    symbols = fetch_okx_symbols()
    if not symbols or "Error" in symbols[0]:
        symbols = ["No symbols available"]
    with gr.Blocks() as demo:
        gr.Markdown("# OKX Price Forecasting with Prophet")
        gr.Markdown(
            "This app uses OKX's spot market candles to predict future price movements. "
            "Select a symbol and timeframe, specify forecast steps, then click 'Generate Forecast'. "
            "No proxies or special access required."
        )
        symbol_dd = gr.Dropdown(
            label="Symbol",
            choices=symbols,
            value=symbols[0] if symbols else None,
        )
        # Offer exactly the timeframes that TIMEFRAME_MAPPING can translate.
        timeframe_dd = gr.Dropdown(
            label="Timeframe",
            choices=list(TIMEFRAME_MAPPING),
            value="1h",
        )
        # step=1 keeps the slider on whole numbers; without it Gradio derives
        # a fractional step and non-integer step counts reach Prophet.
        steps_slider = gr.Slider(
            label="Forecast Steps (hours/days depending on timeframe)",
            minimum=1,
            maximum=100,
            value=10,
            step=1,
        )
        forecast_btn = gr.Button("Generate Forecast")
        output_df = gr.Dataframe(
            label="Future Forecast Only",
            headers=["ds", "yhat", "yhat_lower", "yhat_upper"],
        )
        forecast_btn.click(
            fn=display_forecast,
            inputs=[symbol_dd, timeframe_dd, steps_slider],
            outputs=output_df,
        )
        gr.Markdown(
            "Looking for more automation? Check out this "
            "[crypto trading bot](https://www.gunbot.com)."
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start serving it.
    main().launch()