import math

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import requests
from prophet import Prophet

# Constants for API endpoints
OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"
# The candles endpoint only serves a recent window; older bars are paged from here.
OKX_HISTORY_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/history-candles"

# Per-request page sizes (the history endpoint accepts a smaller `limit`).
OKX_CANDLE_PAGE_LIMIT = 300
OKX_HISTORY_CANDLE_PAGE_LIMIT = 100

# Seconds to wait for an OKX HTTP response before giving up.
REQUEST_TIMEOUT = 10

TIMEFRAME_MAPPING = {
    "1m": "1m",
    "5m": "5m",
    "15m": "15m",
    "30m": "30m",
    "1h": "1H",
    "2h": "2H",
    "4h": "4H",
    "6h": "6H",
    "12h": "12H",
    "1d": "1D",
    "1w": "1W",
}

# pandas offset alias for each UI timeframe, passed to Prophet as the future
# `freq` so that every forecast step advances by exactly one candle.
PANDAS_FREQ_MAPPING = {
    "1m": "min",
    "5m": "5min",
    "15m": "15min",
    "30m": "30min",
    "1h": "h",
    "2h": "2h",
    "4h": "4h",
    "6h": "6h",
    "12h": "12h",
    "1d": "D",
    "1w": "7D",
}


# Function to calculate technical indicators
def calculate_technical_indicators(df):
    """Add RSI(14), MACD(12, 26, 9) and 20-period Bollinger Band columns to ``df``.

    ``df`` must have a numeric ``close`` column ordered oldest-first, because
    every indicator is computed with ``diff``/``rolling``/``ewm`` along row order.
    The frame is modified in place and also returned.

    Adds the columns RSI, MACD, Signal_Line, MA20, BB_upper and BB_lower. RSI and
    the Bollinger columns are NaN until their rolling window is full.
    """
    close = df['close']

    # RSI Calculation (simple rolling means of gains and losses)
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    df['RSI'] = 100 - (100 / (1 + rs))

    # MACD Calculation
    exp1 = close.ewm(span=12, adjust=False).mean()
    exp2 = close.ewm(span=26, adjust=False).mean()
    df['MACD'] = exp1 - exp2
    df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()

    # Bollinger Bands Calculation (std computed once and shared by both bands)
    rolling_std = close.rolling(window=20).std()
    df['MA20'] = close.rolling(window=20).mean()
    df['BB_upper'] = df['MA20'] + 2 * rolling_std
    df['BB_lower'] = df['MA20'] - 2 * rolling_std
    return df


# Function to create technical analysis charts
def create_technical_charts(df):
    """Build the price/Bollinger, RSI and MACD figures from an indicator frame.

    Expects the columns produced by ``fetch_okx_candles`` (timestamp, OHLC and
    the indicator columns from ``calculate_technical_indicators``).

    Returns:
        Tuple ``(price_fig, rsi_fig, macd_fig)`` of plotly figures.
    """
    # Price and Bollinger Bands Chart
    fig1 = go.Figure()
    fig1.add_trace(go.Candlestick(
        x=df['timestamp'],
        open=df['open'],
        high=df['high'],
        low=df['low'],
        close=df['close'],
        name='Price'
    ))
    fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_upper'], name='Upper BB',
                              line=dict(color='gray', dash='dash')))
    fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_lower'], name='Lower BB',
                              line=dict(color='gray', dash='dash')))
    fig1.update_layout(title='Price and Bollinger Bands', xaxis_title='Date', yaxis_title='Price')

    # RSI Chart (70/30 overbought/oversold guide lines)
    fig2 = go.Figure()
    fig2.add_trace(go.Scatter(x=df['timestamp'], y=df['RSI'], name='RSI'))
    fig2.add_hline(y=70, line_dash="dash", line_color="red")
    fig2.add_hline(y=30, line_dash="dash", line_color="green")
    fig2.update_layout(title='RSI Indicator', xaxis_title='Date', yaxis_title='RSI')

    # MACD Chart
    fig3 = go.Figure()
    fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['MACD'], name='MACD'))
    fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['Signal_Line'], name='Signal Line'))
    fig3.update_layout(title='MACD', xaxis_title='Date', yaxis_title='Value')

    return fig1, fig2, fig3


# Fetch available symbols from OKX API
def fetch_okx_symbols():
    """Return OKX spot instrument ids, with "BTC-USDT" first and listed once.

    Best effort: any network or parsing failure is printed and ``["BTC-USDT"]``
    is returned so the UI can still start.
    """
    try:
        resp = requests.get(OKX_TICKERS_ENDPOINT, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        data = resp.json().get("data", [])
        symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
    except Exception as e:
        print(f"Error fetching symbols: {e}")
        return ["BTC-USDT"]
    # BTC-USDT is the dropdown default; keep it first without duplicating it.
    return ["BTC-USDT"] + [s for s in symbols if s != "BTC-USDT"]


# Fetch historical candle data from OKX API
def fetch_okx_candles(symbol, timeframe="1H", total=2000):
    """Fetch up to ``total`` of the most recent OHLC candles for ``symbol``.

    OKX returns candles newest-first, one page per request. This pages backwards
    in time with the ``after`` cursor (timestamp of the oldest candle received so
    far), switching from the recent-candles endpoint to the history endpoint once
    the former returns no more data.

    Args:
        symbol: OKX instrument id, e.g. "BTC-USDT".
        timeframe: OKX bar size, e.g. "1H" (a value of TIMEFRAME_MAPPING).
        total: maximum number of candles to return.

    Returns:
        DataFrame sorted by ascending ``timestamp`` with float open/high/low/close
        and the indicator columns from ``calculate_technical_indicators``; an
        empty DataFrame on request/parse errors or when no candles are available.
    """
    total = int(total)
    rows = []
    after = None  # pagination cursor: OKX returns records older than this ts
    endpoint, page_limit = OKX_CANDLE_ENDPOINT, OKX_CANDLE_PAGE_LIMIT

    while len(rows) < total:
        params = {"instId": symbol, "bar": timeframe, "limit": page_limit}
        if after is not None:
            params["after"] = after
        try:
            resp = requests.get(endpoint, params=params, timeout=REQUEST_TIMEOUT)
            resp.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            data = resp.json().get("data", [])
        except requests.exceptions.RequestException as e:
            print(f"Error fetching candles: {e}")
            return pd.DataFrame()
        except (ValueError, KeyError) as e:
            print(f"Error parsing candle data: {e}")
            return pd.DataFrame()

        if not data:
            if endpoint == OKX_CANDLE_ENDPOINT:
                # Recent window exhausted; continue from the same cursor on history.
                endpoint, page_limit = OKX_HISTORY_CANDLE_ENDPOINT, OKX_HISTORY_CANDLE_PAGE_LIMIT
                continue
            break

        oldest_ts = min((row[0] for row in data), key=int)
        if after is not None and int(oldest_ts) >= int(after):
            break  # cursor did not move backwards; stop instead of re-reading pages
        # Rows carry more fields than OHLC (volume etc.); keep ts, o, h, l, c only.
        rows.extend(row[:5] for row in data)
        after = oldest_ts

    if not rows:
        return pd.DataFrame()

    df_all = pd.DataFrame(rows, columns=["timestamp", "open", "high", "low", "close"])

    # Convert timestamps to datetime (cast the ms strings to ints first) and prices to float
    df_all["timestamp"] = pd.to_datetime(df_all["timestamp"].astype("int64"), unit="ms")
    numeric_cols = ["open", "high", "low", "close"]
    df_all[numeric_cols] = df_all[numeric_cols].astype(float)

    # Pages arrive newest-first; indicators need a unique, oldest-first series.
    df_all = (
        df_all.drop_duplicates(subset="timestamp")
        .sort_values("timestamp")
        .tail(total)
        .reset_index(drop=True)
    )
    return calculate_technical_indicators(df_all)


# Prepare data for Prophet forecasting
def prepare_data_for_prophet(df):
    """Return the ``timestamp``/``close`` columns renamed to Prophet's ``ds``/``y``."""
    if df.empty:
        return pd.DataFrame(columns=["ds", "y"])
    df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
    return df_prophet[["ds", "y"]]


# Perform forecasting using Prophet
def prophet_forecast(df_prophet, periods=10, freq="h", daily_seasonality=False,
                     weekly_seasonality=False, yearly_seasonality=False,
                     seasonality_mode="additive", changepoint_prior_scale=0.05):
    """Fit Prophet on ``df_prophet`` and predict ``periods`` steps of ``freq`` ahead.

    Returns:
        ``(forecast, "")`` where ``forecast`` covers history plus future dates, or
        ``(empty DataFrame, message)`` if there is no data or Prophet fails.
    """
    if df_prophet.empty:
        return pd.DataFrame(), "No data for Prophet."

    try:
        model = Prophet(
            daily_seasonality=daily_seasonality,
            weekly_seasonality=weekly_seasonality,
            yearly_seasonality=yearly_seasonality,
            seasonality_mode=seasonality_mode,
            changepoint_prior_scale=changepoint_prior_scale,
        )
        model.fit(df_prophet)
        # Gradio sliders may deliver floats; pandas date ranges need an int count.
        future = model.make_future_dataframe(periods=int(periods), freq=freq)
        forecast = model.predict(future)
        return forecast, ""
    except Exception as e:
        return pd.DataFrame(), f"Forecast error: {e}"


# Wrapper function for forecasting
def prophet_wrapper(df_prophet, forecast_steps, freq, daily_seasonality, weekly_seasonality,
                    yearly_seasonality, seasonality_mode, changepoint_prior_scale):
    """Run ``prophet_forecast`` and keep only rows dated after the history.

    Returns:
        ``(future_df, "")`` with columns ds, yhat, yhat_lower, yhat_upper, or
        ``(empty DataFrame, message)`` on insufficient data or forecast failure.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."

    full_forecast, err = prophet_forecast(
        df_prophet,
        periods=int(forecast_steps),
        freq=freq,
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale,
    )
    if err:
        return pd.DataFrame(), err

    # Prophet collapses duplicate history dates, so select the future by date
    # rather than by row position.
    last_history_ds = pd.to_datetime(df_prophet["ds"]).max()
    is_future = full_forecast["ds"] > last_history_ds
    future_only = full_forecast.loc[is_future, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
    return future_only.reset_index(drop=True), ""


# Create forecast plot
def create_forecast_plot(forecast_df):
    """Plot ``yhat`` with a shaded ``yhat_lower``..``yhat_upper`` band.

    Returns an empty figure when ``forecast_df`` is empty.
    """
    if forecast_df.empty:
        return go.Figure()

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat"],
        mode="lines",
        name="Forecast",
        line=dict(color="blue", width=2)
    ))

    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat_lower"],
        fill=None,
        mode="lines",
        line=dict(width=0),
        showlegend=True,
        name="Lower Bound"
    ))

    # "tonexty" fills down to the previous trace (the lower bound).
    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat_upper"],
        fill="tonexty",
        mode="lines",
        line=dict(width=0),
        name="Upper Bound"
    ))

    fig.update_layout(
        title="Price Forecast",
        xaxis_title="Time",
        yaxis_title="Price",
        hovermode="x unified",
        template="plotly_white",
    )
    return fig


def _report_error(message):
    """Print ``message`` and, if this Gradio version provides it, show a UI warning."""
    print(f"Forecast failed: {message}")
    warn = getattr(gr, "Warning", None)
    if warn is not None:
        warn(message)


# Function to display forecast and technical analysis charts
def display_forecast(symbol, timeframe, forecast_steps, total_candles, daily_seasonality,
                     weekly_seasonality, yearly_seasonality, seasonality_mode,
                     changepoint_prior_scale):
    """Gradio click handler: fetch data, forecast, and build every output.

    Returns:
        ``(forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_table)``; on
        error the four plots are ``None`` and the table is an empty DataFrame.
    """
    df_raw, forecast_df, error = predict(
        symbol=symbol,
        timeframe=timeframe,
        forecast_steps=forecast_steps,
        total_candles=total_candles,
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale
    )

    if error:
        _report_error(error)
        return None, None, None, None, pd.DataFrame()  # Return empty dataframe for forecast_df

    forecast_plot = create_forecast_plot(forecast_df)
    tech_plot, rsi_plot, macd_plot = create_technical_charts(df_raw)

    # Prepare forecast data for the Dataframe output
    forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
    forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast",
                                        "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"},
                               inplace=True)

    return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display


# Main prediction function
def predict(symbol, timeframe, forecast_steps, total_candles, daily_seasonality,
            weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
    """Fetch candles for ``symbol``/``timeframe`` and forecast ``forecast_steps`` candles.

    Returns:
        ``(df_raw, future_df, "")`` on success, otherwise two empty DataFrames and
        an error message.
    """
    okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
    df_raw = fetch_okx_candles(symbol=symbol, timeframe=okx_bar, total=int(total_candles))
    if df_raw.empty:
        return pd.DataFrame(), pd.DataFrame(), "No data fetched."

    df_prophet = prepare_data_for_prophet(df_raw)
    # One forecast step == one candle of the chosen timeframe (hourly fallback
    # matches the "1H" default bar above).
    freq = PANDAS_FREQ_MAPPING.get(timeframe, "h")

    future_df, err2 = prophet_wrapper(
        df_prophet,
        int(forecast_steps),
        freq,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if err2:
        return pd.DataFrame(), pd.DataFrame(), err2

    return df_raw, future_df, ""


# Main Gradio app setup
def main():
    """Build and return the CryptoVision Gradio Blocks app (not launched)."""
    symbols = fetch_okx_symbols()

    with gr.Blocks(theme=gr.themes.Base()) as demo:
        # Header
        with gr.Row():
            gr.Markdown("# CryptoVision")

        # Market Selection and Forecast Parameters
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Market Selection")
                symbol_dd = gr.Dropdown(
                    label="Trading Pair",
                    choices=symbols,
                    value="BTC-USDT"
                )
                timeframe_dd = gr.Dropdown(
                    label="Timeframe",
                    choices=list(TIMEFRAME_MAPPING.keys()),
                    value="1h"
                )

            with gr.Column(scale=1):
                gr.Markdown("### Forecast Parameters")
                forecast_steps_slider = gr.Slider(
                    label="Forecast Steps",
                    minimum=1,
                    maximum=100,
                    value=24,
                    step=1
                )
                total_candles_slider = gr.Slider(
                    label="Historical Candles",
                    minimum=300,
                    maximum=3000,
                    value=2000,
                    step=100
                )

        # Advanced Settings
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Advanced Settings")
                daily_box = gr.Checkbox(label="Daily Seasonality", value=True)
                weekly_box = gr.Checkbox(label="Weekly Seasonality", value=True)
                yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
                seasonality_mode_dd = gr.Dropdown(
                    label="Seasonality Mode",
                    choices=["additive", "multiplicative"],
                    value="additive"
                )
                changepoint_scale_slider = gr.Slider(
                    label="Changepoint Prior Scale",
                    minimum=0.01,
                    maximum=1.0,
                    step=0.01,
                    value=0.05
                )

        # Generate Forecast Button
        forecast_btn = gr.Button("Generate Forecast", variant="primary", size="lg")

        # Output Plots
        with gr.Row():
            forecast_plot = gr.Plot(label="Price Forecast")

        with gr.Row():
            tech_plot = gr.Plot(label="Technical Analysis")
            rsi_plot = gr.Plot(label="RSI Indicator")

        with gr.Row():
            macd_plot = gr.Plot(label="MACD")

        # Output Data Table
        forecast_df = gr.Dataframe(
            label="Forecast Data",
            headers=["Date", "Forecast", "Lower Bound", "Upper Bound"]
        )

        # Button click functionality (input order matches display_forecast's signature)
        forecast_btn.click(
            fn=display_forecast,
            inputs=[
                symbol_dd,
                timeframe_dd,
                forecast_steps_slider,
                total_candles_slider,
                daily_box,
                weekly_box,
                yearly_box,
                seasonality_mode_dd,
                changepoint_scale_slider,
            ],
            outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df]
        )

    return demo


if __name__ == "__main__":
    app = main()
    app.launch()