Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,14 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
-
import
|
4 |
from prophet import Prophet
|
5 |
import plotly.graph_objs as go
|
6 |
import math
|
7 |
-
import
|
8 |
-
from
|
9 |
-
from
|
|
|
10 |
|
11 |
# --- Replace with your Alpha Vantage API key ---
|
12 |
ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
|
@@ -16,6 +18,63 @@ CRYPTO_SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
|
16 |
STOCK_SYMBOLS = ["AAPL", "MSFT"]
|
17 |
INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
# --- Technical Analysis Functions ---
|
20 |
def calculate_technical_indicators(df):
|
21 |
"""Calculates RSI, MACD, and Bollinger Bands."""
|
@@ -164,31 +223,60 @@ def create_forecast_plot(forecast_df):
|
|
164 |
)
|
165 |
return fig
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
# --- Main Prediction and Display Function ---
|
168 |
-
def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
|
169 |
"""Main function to orchestrate data fetching, analysis, and prediction."""
|
170 |
df = pd.DataFrame()
|
171 |
error_message = ""
|
172 |
-
|
173 |
# 1. Data Fetching
|
174 |
-
|
175 |
-
|
176 |
-
df = fetch_crypto_data(symbol)
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
except Exception as e:
|
183 |
-
error_message = f"Error fetching stock data: {e}"
|
184 |
-
else:
|
185 |
-
error_message = "Invalid market type selected."
|
186 |
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
189 |
|
|
|
|
|
|
|
190 |
# 2. Preprocessing & Technical Analysis
|
191 |
-
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
192 |
numeric_cols = ["open", "high", "low", "close", "volume"]
|
193 |
df[numeric_cols] = df[numeric_cols].astype(float)
|
194 |
df = calculate_technical_indicators(df)
|
@@ -209,7 +297,7 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
|
|
209 |
|
210 |
if prophet_error:
|
211 |
error_message = f"Prophet Error: {prophet_error}"
|
212 |
-
return None, None, None, None, None, "", error_message #Return error
|
213 |
|
214 |
forecast_plot = create_forecast_plot(forecast_df)
|
215 |
|
@@ -233,8 +321,8 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
|
|
233 |
# Prepare forecast data for the Dataframe output
|
234 |
forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
|
235 |
forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
|
|
|
236 |
|
237 |
-
return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message #Return error
|
238 |
# --- Gradio Interface ---
|
239 |
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
240 |
gr.Markdown("# Market Analysis and Prediction")
|
@@ -250,6 +338,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
250 |
yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
|
251 |
seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
|
252 |
changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
|
|
|
253 |
|
254 |
with gr.Column():
|
255 |
forecast_plot = gr.Plot(label="Price Forecast")
|
@@ -260,6 +349,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
260 |
macd_plot = gr.Plot(label="MACD")
|
261 |
forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
|
262 |
growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
|
|
|
263 |
|
264 |
# Event Listener to update symbol dropdown based on market type
|
265 |
def update_symbol_choices(market_type):
|
@@ -283,8 +373,9 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
|
|
283 |
yearly_box,
|
284 |
seasonality_mode_dd,
|
285 |
changepoint_scale_slider,
|
|
|
286 |
],
|
287 |
-
outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output]
|
288 |
)
|
289 |
|
290 |
if __name__ == "__main__":
|
|
|
1 |
+
# app.py
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
from prophet import Prophet
|
6 |
import plotly.graph_objs as go
|
7 |
import math
|
8 |
+
import requests # For API calls
|
9 |
+
from sklearn.ensemble import RandomForestClassifier # For the model
|
10 |
+
from textblob import TextBlob # For sentiment analysis (optional)
|
11 |
+
|
12 |
|
13 |
# --- Replace with your Alpha Vantage API key ---
|
14 |
ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
|
|
|
18 |
STOCK_SYMBOLS = ["AAPL", "MSFT"]
|
19 |
INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
|
20 |
|
21 |
+
# --- Data Fetching Functions ---
|
22 |
+
def fetch_crypto_data(symbol, interval="1h", limit=100):
|
23 |
+
"""Fetch crypto market data from Binance."""
|
24 |
+
try:
|
25 |
+
url = f"https://api.binance.com/api/v3/klines"
|
26 |
+
params = {"symbol": symbol, "interval": interval, "limit": limit}
|
27 |
+
response = requests.get(url, params=params)
|
28 |
+
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
|
29 |
+
data = response.json()
|
30 |
+
df = pd.DataFrame(data, columns=["timestamp", "open", "high", "low", "close", "volume"])
|
31 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
|
32 |
+
df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
|
33 |
+
return df.dropna()
|
34 |
+
except requests.exceptions.RequestException as e:
|
35 |
+
raise Exception(f"Error fetching crypto data: {e}")
|
36 |
+
except (ValueError, KeyError) as e:
|
37 |
+
raise Exception(f"Error parsing crypto data: {e}")
|
38 |
+
|
39 |
+
|
40 |
+
def fetch_stock_data(symbol, interval="60min"):
|
41 |
+
"""Fetch stock market data from Alpha Vantage."""
|
42 |
+
try:
|
43 |
+
url = f"https://www.alphavantage.co/query"
|
44 |
+
params = {"function": "TIME_SERIES_INTRADAY", "symbol": symbol, "interval": interval,
|
45 |
+
"apikey": ALPHA_VANTAGE_API_KEY}
|
46 |
+
response = requests.get(url, params=params)
|
47 |
+
response.raise_for_status() # Raise HTTPError for bad responses
|
48 |
+
data = response.json()
|
49 |
+
if "Time Series (60min)" in data: # Check the JSON for the correct key
|
50 |
+
time_series_data = data["Time Series (60min)"]
|
51 |
+
df = pd.DataFrame(time_series_data).T.reset_index()
|
52 |
+
df.columns = ["timestamp", "open", "high", "low", "close", "volume"] # Standardize
|
53 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
54 |
+
df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
|
55 |
+
return df.dropna()
|
56 |
+
else:
|
57 |
+
raise Exception(f"Error: Could not retrieve stock data. Check API Key, Symbol, and Interval. Response: {data}")
|
58 |
+
|
59 |
+
except requests.exceptions.RequestException as e:
|
60 |
+
raise Exception(f"Error fetching stock data: {e}")
|
61 |
+
except (ValueError, KeyError) as e:
|
62 |
+
raise Exception(f"Error parsing stock data: {e}")
|
63 |
+
|
64 |
+
def fetch_sentiment_data(keyword): # Placeholder - replace with a real sentiment analysis method
|
65 |
+
"""Analyze sentiment from social media (placeholder)."""
|
66 |
+
try:
|
67 |
+
tweets = [
|
68 |
+
f"{keyword} is going to moon!",
|
69 |
+
f"I hate {keyword}, it's trash!",
|
70 |
+
f"{keyword} is amazing!"
|
71 |
+
]
|
72 |
+
sentiments = [TextBlob(tweet).sentiment.polarity for tweet in tweets]
|
73 |
+
return sum(sentiments) / len(sentiments) / len(sentiments) if sentiments else 0 # Avoid ZeroDivisionError
|
74 |
+
except Exception as e:
|
75 |
+
print(f"Sentiment analysis error: {e}")
|
76 |
+
return 0
|
77 |
+
|
78 |
# --- Technical Analysis Functions ---
|
79 |
def calculate_technical_indicators(df):
|
80 |
"""Calculates RSI, MACD, and Bollinger Bands."""
|
|
|
223 |
)
|
224 |
return fig
|
225 |
|
226 |
+
# --- Model Training and Prediction ---
|
227 |
+
model = RandomForestClassifier() # Moved here
|
228 |
+
|
229 |
+
def train_model(df):
|
230 |
+
"""Train the AI model."""
|
231 |
+
if df.empty:
|
232 |
+
return # Or raise an exception, or return a default model.
|
233 |
+
df["target"] = (df["close"].pct_change() > 0.05).astype(int) # Target: 1 if price increased by >5%
|
234 |
+
features = df[["close", "volume"]].dropna()
|
235 |
+
target = df["target"].dropna()
|
236 |
+
if not features.empty and not target.empty: #check data is available
|
237 |
+
model.fit(features, target)
|
238 |
+
else:
|
239 |
+
print("Not enough data for model training.")
|
240 |
+
|
241 |
+
def predict_growth(latest_data):
|
242 |
+
"""Predict explosive growth."""
|
243 |
+
if not hasattr(model, 'estimators_') or len(model.estimators_) == 0: # Check if model is trained
|
244 |
+
return [0] # Or return an error message, or a default value
|
245 |
+
|
246 |
+
try:
|
247 |
+
prediction = model.predict(latest_data.reshape(1, -1))
|
248 |
+
return prediction
|
249 |
+
except Exception as e:
|
250 |
+
print(f"Prediction error: {e}")
|
251 |
+
return [0]
|
252 |
+
|
253 |
# --- Main Prediction and Display Function ---
|
254 |
+
def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale, sentiment_keyword=""):
|
255 |
"""Main function to orchestrate data fetching, analysis, and prediction."""
|
256 |
df = pd.DataFrame()
|
257 |
error_message = ""
|
258 |
+
sentiment_score = 0 # Initialize sentiment score
|
259 |
# 1. Data Fetching
|
260 |
+
try:
|
261 |
+
if market_type == "Crypto":
|
262 |
+
df = fetch_crypto_data(symbol, interval=interval)
|
263 |
+
elif market_type == "Stock":
|
264 |
+
df = fetch_stock_data(symbol, interval=interval)
|
265 |
+
else:
|
266 |
+
error_message = "Invalid market type selected."
|
267 |
+
return None, None, None, None, None, "", error_message, 0 # Also return sentiment
|
|
|
|
|
|
|
|
|
268 |
|
269 |
+
if sentiment_keyword: # If a keyword for sentiment is entered:
|
270 |
+
sentiment_score = fetch_sentiment_data(sentiment_keyword)
|
271 |
+
except Exception as e:
|
272 |
+
error_message = f"Data Fetching Error: {e}"
|
273 |
+
return None, None, None, None, None, "", error_message, 0 #Return error + sentiment
|
274 |
|
275 |
+
if df.empty:
|
276 |
+
error_message = "No data fetched."
|
277 |
+
return None, None, None, None, None, "", error_message, 0 # Return empty + sentiment
|
278 |
# 2. Preprocessing & Technical Analysis
|
279 |
+
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
280 |
numeric_cols = ["open", "high", "low", "close", "volume"]
|
281 |
df[numeric_cols] = df[numeric_cols].astype(float)
|
282 |
df = calculate_technical_indicators(df)
|
|
|
297 |
|
298 |
if prophet_error:
|
299 |
error_message = f"Prophet Error: {prophet_error}"
|
300 |
+
return None, None, None, None, None, "", error_message, sentiment_score #Return prophet error
|
301 |
|
302 |
forecast_plot = create_forecast_plot(forecast_df)
|
303 |
|
|
|
321 |
# Prepare forecast data for the Dataframe output
|
322 |
forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
|
323 |
forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
|
324 |
+
return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message, sentiment_score
|
325 |
|
|
|
326 |
# --- Gradio Interface ---
|
327 |
with gr.Blocks(theme=gr.themes.Base()) as demo:
|
328 |
gr.Markdown("# Market Analysis and Prediction")
|
|
|
338 |
yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
|
339 |
seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
|
340 |
changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
|
341 |
+
sentiment_keyword_txt = gr.Textbox(label="Sentiment Keyword (optional)") #Add Sentiment input
|
342 |
|
343 |
with gr.Column():
|
344 |
forecast_plot = gr.Plot(label="Price Forecast")
|
|
|
349 |
macd_plot = gr.Plot(label="MACD")
|
350 |
forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
|
351 |
growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
|
352 |
+
sentiment_label_output = gr.Number(label="Sentiment Score") # Added for sentiment output
|
353 |
|
354 |
# Event Listener to update symbol dropdown based on market type
|
355 |
def update_symbol_choices(market_type):
|
|
|
373 |
yearly_box,
|
374 |
seasonality_mode_dd,
|
375 |
changepoint_scale_slider,
|
376 |
+
sentiment_keyword_txt, # Add sentiment keyword to the input
|
377 |
],
|
378 |
+
outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output, gr.Label(label="Error Message"), sentiment_label_output] # Add sentiment score to the output
|
379 |
)
|
380 |
|
381 |
if __name__ == "__main__":
|