"""Gradio dashboard for market analysis and short-term price forecasting.

The app fetches OHLCV data for a crypto or stock symbol, computes RSI, MACD
and Bollinger Bands, fits a Prophet model on the closing price, and asks the
project's classifier (``src.model``) whether explosive growth is expected.
"""

import math
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import requests
from prophet import Prophet

from data_fetcher import fetch_crypto_data, fetch_stock_data, fetch_sentiment_data  # Data fetcher module
from src.model import train_model, predict_growth  # Model functions

# --- Alpha Vantage API key ---
# Prefer the ALPHA_VANTAGE_API_KEY environment variable so the real key never
# has to be committed; the placeholder is kept as the fallback value.
ALPHA_VANTAGE_API_KEY = os.environ.get(
    "ALPHA_VANTAGE_API_KEY", "YOUR_ALPHA_VANTAGE_API_KEY"
)

# --- Constants ---
CRYPTO_SYMBOLS = ["BTCUSDT", "ETHUSDT"]
STOCK_SYMBOLS = ["AAPL", "MSFT"]
INTERVAL_OPTIONS = ["1h", "60min"]  # Consistent naming

# Columns of the Prophet forecast that the UI consumes.
_FORECAST_COLUMNS = ["ds", "yhat", "yhat_lower", "yhat_upper"]


# --- Technical Analysis Functions ---
def calculate_technical_indicators(df):
    """Add RSI, MACD and Bollinger Band columns to ``df``.

    The frame is modified in place and also returned. It must contain a
    ``close`` column. An empty frame is returned untouched.

    Columns added: ``RSI``, ``MACD``, ``Signal_Line``, ``MA20``,
    ``BB_upper``, ``BB_lower``.
    """
    if df.empty:
        return df

    close = df["close"]

    # RSI (14-period simple moving average of gains and losses).
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    df["RSI"] = 100 - (100 / (1 + rs))

    # MACD (12/26 EMA difference) and its 9-period signal line.
    ema_fast = close.ewm(span=12, adjust=False).mean()
    ema_slow = close.ewm(span=26, adjust=False).mean()
    df["MACD"] = ema_fast - ema_slow
    df["Signal_Line"] = df["MACD"].ewm(span=9, adjust=False).mean()

    # Bollinger Bands: 20-period mean +/- 2 standard deviations.
    # The rolling std is computed once and reused for both bands.
    window_20 = close.rolling(window=20)
    band_width = 2 * window_20.std()
    df["MA20"] = window_20.mean()
    df["BB_upper"] = df["MA20"] + band_width
    df["BB_lower"] = df["MA20"] - band_width

    return df


def create_technical_charts(df):
    """Build the price/Bollinger, RSI and MACD figures.

    ``df`` must already carry the columns produced by
    ``calculate_technical_indicators`` plus ``timestamp``, ``open``, ``high``,
    ``low`` and ``close``.

    Returns a ``(price_fig, rsi_fig, macd_fig)`` tuple, or
    ``(None, None, None)`` when ``df`` is empty.
    """
    if df.empty:
        return None, None, None

    timestamps = df["timestamp"]
    band_style = dict(color="gray", dash="dash")

    fig1 = go.Figure()
    fig1.add_trace(go.Candlestick(
        x=timestamps,
        open=df["open"],
        high=df["high"],
        low=df["low"],
        close=df["close"],
        name="Price",
    ))
    fig1.add_trace(go.Scatter(x=timestamps, y=df["BB_upper"], name="Upper BB", line=band_style))
    fig1.add_trace(go.Scatter(x=timestamps, y=df["BB_lower"], name="Lower BB", line=band_style))
    fig1.update_layout(title="Price and Bollinger Bands", xaxis_title="Date", yaxis_title="Price")

    fig2 = go.Figure()
    fig2.add_trace(go.Scatter(x=timestamps, y=df["RSI"], name="RSI"))
    # Reference lines at the 70 and 30 RSI levels.
    fig2.add_hline(y=70, line_dash="dash", line_color="red")
    fig2.add_hline(y=30, line_dash="dash", line_color="green")
    fig2.update_layout(title="RSI Indicator", xaxis_title="Date", yaxis_title="RSI")

    fig3 = go.Figure()
    fig3.add_trace(go.Scatter(x=timestamps, y=df["MACD"], name="MACD"))
    fig3.add_trace(go.Scatter(x=timestamps, y=df["Signal_Line"], name="Signal Line"))
    fig3.update_layout(title="MACD", xaxis_title="Date", yaxis_title="Value")

    return fig1, fig2, fig3


# --- Prophet Forecasting Functions ---
def prepare_data_for_prophet(df):
    """Return a two-column ``ds``/``y`` frame built from ``timestamp``/``close``.

    An empty input yields an empty frame with the ``ds`` and ``y`` columns.
    """
    if df.empty:
        return pd.DataFrame(columns=["ds", "y"])
    df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
    return df_prophet[["ds", "y"]]


def prophet_forecast(df_prophet, periods=10, freq="h", daily_seasonality=False,
                     weekly_seasonality=False, yearly_seasonality=False,
                     seasonality_mode="additive", changepoint_prior_scale=0.05):
    """Fit Prophet on ``df_prophet`` and predict ``periods`` steps ahead.

    Returns ``(forecast, "")`` on success, where ``forecast`` covers the
    history plus the future rows, or ``(empty DataFrame, message)`` when there
    is no data or Prophet raises.
    """
    if df_prophet.empty:
        return pd.DataFrame(), "No data for Prophet."

    try:
        model = Prophet(
            daily_seasonality=daily_seasonality,
            weekly_seasonality=weekly_seasonality,
            yearly_seasonality=yearly_seasonality,
            seasonality_mode=seasonality_mode,
            changepoint_prior_scale=changepoint_prior_scale,
        )
        model.fit(df_prophet)
        future = model.make_future_dataframe(periods=periods, freq=freq)
        forecast = model.predict(future)
        return forecast, ""
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return pd.DataFrame(), f"Forecast error: {e}"


def prophet_wrapper(df_prophet, forecast_steps, freq, daily_seasonality,
                    weekly_seasonality, yearly_seasonality, seasonality_mode,
                    changepoint_prior_scale):
    """Run the Prophet forecast and keep only the future rows.

    Returns ``(future_forecast, "")`` on success or
    ``(empty DataFrame, message)`` when there are fewer than 10 rows or the
    forecast fails.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."

    full_forecast, err = prophet_forecast(
        df_prophet,
        periods=int(forecast_steps),  # UI sliders may deliver a float
        freq=freq,
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale,
    )
    if err:
        return pd.DataFrame(), err

    # Select future rows by timestamp rather than by position: a positional
    # slice at len(df_prophet) is only right when the forecast holds exactly
    # one history row per input row, which breaks with duplicate timestamps.
    last_observed = pd.to_datetime(df_prophet["ds"]).max()
    is_future = full_forecast["ds"] > last_observed
    future_only = full_forecast.loc[is_future, _FORECAST_COLUMNS]
    return future_only, ""


def create_forecast_plot(forecast_df):
    """Plot the forecast line with its shaded lower/upper uncertainty band.

    Returns an empty figure when ``forecast_df`` is empty.
    """
    fig = go.Figure()
    if forecast_df.empty:
        return fig

    dates = forecast_df["ds"]
    fig.add_trace(go.Scatter(
        x=dates,
        y=forecast_df["yhat"],
        mode="lines",
        name="Forecast",
        line=dict(color="blue", width=2),
    ))
    fig.add_trace(go.Scatter(
        x=dates,
        y=forecast_df["yhat_lower"],
        fill=None,
        mode="lines",
        line=dict(width=0),
        showlegend=True,
        name="Lower Bound",
    ))
    # "tonexty" fills the area between this trace and the lower bound above.
    fig.add_trace(go.Scatter(
        x=dates,
        y=forecast_df["yhat_upper"],
        fill="tonexty",
        mode="lines",
        line=dict(width=0),
        name="Upper Bound",
    ))
    fig.update_layout(
        title="Price Forecast",
        xaxis_title="Time",
        yaxis_title="Price",
        hovermode="x unified",
        template="plotly_white",
    )
    return fig


# --- Main Prediction and Display Function ---
def _failure(message):
    """Build the 7-tuple ``analyze_market`` returns when it cannot proceed."""
    return None, None, None, None, None, "", message


def analyze_market(market_type, symbol, interval, forecast_steps,
                   daily_seasonality, weekly_seasonality, yearly_seasonality,
                   seasonality_mode, changepoint_prior_scale):
    """Fetch data, compute indicators, forecast and classify one symbol.

    ``interval`` only selects the forecast frequency; the fetchers are called
    with the symbol alone.

    Returns a 7-tuple:
    ``(forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_table,
    growth_label, error_message)``. On failure the first five entries are
    ``None``, the label is ``""`` and ``error_message`` explains the problem.
    """
    df = pd.DataFrame()
    error_message = ""

    # 1. Data Fetching
    if market_type == "Crypto":
        try:
            df = fetch_crypto_data(symbol)
        except Exception as e:
            error_message = f"Error fetching crypto data: {e}"
    elif market_type == "Stock":
        try:
            df = fetch_stock_data(symbol)
        except Exception as e:
            error_message = f"Error fetching stock data: {e}"
    else:
        error_message = "Invalid market type selected."

    if df.empty:
        # A fetcher may return an empty frame without raising; never hand the
        # user a blank screen with no explanation.
        return _failure(error_message or f"No data returned for {symbol}.")

    # 2. Preprocessing & Technical Analysis
    numeric_cols = ["open", "high", "low", "close", "volume"]
    try:
        df["timestamp"] = pd.to_datetime(df["timestamp"])  # No unit arg as it's handled in fetcher
        df[numeric_cols] = df[numeric_cols].astype(float)
    except (KeyError, ValueError, TypeError) as e:
        return _failure(f"Error preprocessing data: {e}")
    df = calculate_technical_indicators(df)

    # 3. Prophet Forecasting
    df_prophet = prepare_data_for_prophet(df)
    # Hourly intervals map to "h"; anything else falls back to calendar days.
    freq = "h" if interval in ("1h", "60min") else "D"
    forecast_df, prophet_error = prophet_wrapper(
        df_prophet,
        forecast_steps,
        freq,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if prophet_error:
        return _failure(f"Prophet Error: {prophet_error}")

    forecast_plot = create_forecast_plot(forecast_df)

    # 4. Create the Charts
    tech_plot, rsi_plot, macd_plot = create_technical_charts(df)

    # 5. Model Training and Prediction (simplified)
    try:
        train_model(df.copy())  # Train on a copy to avoid modifying the original df.
        # df is known to be non-empty here (checked right after fetching).
        latest_data = df[["close", "volume"]].iloc[-1].values  # Last row for prediction.
        growth_prediction = predict_growth(latest_data)
        growth_label = "Yes" if growth_prediction[0] == 1 else "No"
    except Exception as e:
        # Keep the charts; only the growth label is unavailable.
        error_message = f"Model Error: {e}"
        growth_label = "N/A"

    # Prepare forecast data for the Dataframe output
    forecast_df_display = forecast_df.loc[:, _FORECAST_COLUMNS].rename(
        columns={
            "ds": "Date",
            "yhat": "Forecast",
            "yhat_lower": "Lower Bound",
            "yhat_upper": "Upper Bound",
        }
    )

    return (forecast_plot, tech_plot, rsi_plot, macd_plot,
            forecast_df_display, growth_label, error_message)


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# Market Analysis and Prediction")

    with gr.Row():
        with gr.Column():
            market_type_dd = gr.Radio(label="Market Type", choices=["Crypto", "Stock"], value="Crypto")
            symbol_dd = gr.Dropdown(label="Symbol", choices=CRYPTO_SYMBOLS, value="BTCUSDT")  # Start with Crypto
            interval_dd = gr.Dropdown(label="Interval", choices=INTERVAL_OPTIONS, value="1h")
            forecast_steps_slider = gr.Slider(label="Forecast Steps", minimum=1, maximum=100, value=24, step=1)
            daily_box = gr.Checkbox(label="Daily Seasonality", value=True)
            weekly_box = gr.Checkbox(label="Weekly Seasonality", value=True)
            yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
            seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
            changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)

        with gr.Column():
            forecast_plot = gr.Plot(label="Price Forecast")
            with gr.Row():
                tech_plot = gr.Plot(label="Technical Analysis")
                rsi_plot = gr.Plot(label="RSI Indicator")
            with gr.Row():
                macd_plot = gr.Plot(label="MACD")
            forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
            growth_label_output = gr.Label(label="Explosive Growth Prediction")
            # Receives the 7th value returned by analyze_market, so failures
            # are visible to the user.
            error_output = gr.Textbox(label="Status / Errors", interactive=False)

    # Event Listener to update symbol dropdown based on market type
    def update_symbol_choices(market_type):
        """Return a symbol dropdown matching the selected market type."""
        if market_type == "Crypto":
            return gr.Dropdown(choices=CRYPTO_SYMBOLS, value="BTCUSDT")
        elif market_type == "Stock":
            return gr.Dropdown(choices=STOCK_SYMBOLS, value="AAPL")  # Default to AAPL for stock
        return gr.Dropdown(choices=[], value=None)  # Shouldn't happen, but safety check

    market_type_dd.change(fn=update_symbol_choices, inputs=[market_type_dd], outputs=[symbol_dd])

    analyze_button = gr.Button("Analyze Market", variant="primary")
    analyze_button.click(
        fn=analyze_market,
        inputs=[
            market_type_dd,
            symbol_dd,
            interval_dd,
            forecast_steps_slider,
            daily_box,
            weekly_box,
            yearly_box,
            seasonality_mode_dd,
            changepoint_scale_slider,
        ],
        # One output component per value returned by analyze_market.
        outputs=[
            forecast_plot,
            tech_plot,
            rsi_plot,
            macd_plot,
            forecast_df,
            growth_label_output,
            error_output,
        ],
    )

if __name__ == "__main__":
    demo.launch()