MacDash commited on
Commit
c63bfc4
·
verified ·
1 Parent(s): 87def00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -24
app.py CHANGED
@@ -1,12 +1,14 @@
 
1
  import gradio as gr
2
  import pandas as pd
3
- import requests
4
  from prophet import Prophet
5
  import plotly.graph_objs as go
6
  import math
7
- import numpy as np
8
- from data_fetcher import fetch_crypto_data, fetch_stock_data, fetch_sentiment_data # Import the data fetcher module
9
- from src.model import train_model, predict_growth # Import your model functions
 
10
 
11
  # --- Replace with your Alpha Vantage API key ---
12
  ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
@@ -16,6 +18,63 @@ CRYPTO_SYMBOLS = ["BTCUSDT", "ETHUSDT"]
16
  STOCK_SYMBOLS = ["AAPL", "MSFT"]
17
  INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # --- Technical Analysis Functions ---
20
  def calculate_technical_indicators(df):
21
  """Calculates RSI, MACD, and Bollinger Bands."""
@@ -164,31 +223,60 @@ def create_forecast_plot(forecast_df):
164
  )
165
  return fig
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # --- Main Prediction and Display Function ---
168
- def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
169
  """Main function to orchestrate data fetching, analysis, and prediction."""
170
  df = pd.DataFrame()
171
  error_message = ""
172
-
173
  # 1. Data Fetching
174
- if market_type == "Crypto":
175
- try:
176
- df = fetch_crypto_data(symbol)
177
- except Exception as e:
178
- error_message = f"Error fetching crypto data: {e}"
179
- elif market_type == "Stock":
180
- try:
181
- df = fetch_stock_data(symbol)
182
- except Exception as e:
183
- error_message = f"Error fetching stock data: {e}"
184
- else:
185
- error_message = "Invalid market type selected."
186
 
187
- if df.empty:
188
- return None, None, None, None, None, "", error_message # Correctly pass the error message
 
 
 
189
 
 
 
 
190
  # 2. Preprocessing & Technical Analysis
191
- df["timestamp"] = pd.to_datetime(df["timestamp"]) # No unit arg as it's handled in fetcher
192
  numeric_cols = ["open", "high", "low", "close", "volume"]
193
  df[numeric_cols] = df[numeric_cols].astype(float)
194
  df = calculate_technical_indicators(df)
@@ -209,7 +297,7 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
209
 
210
  if prophet_error:
211
  error_message = f"Prophet Error: {prophet_error}"
212
- return None, None, None, None, None, "", error_message #Return error
213
 
214
  forecast_plot = create_forecast_plot(forecast_df)
215
 
@@ -233,8 +321,8 @@ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonal
233
  # Prepare forecast data for the Dataframe output
234
  forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
235
  forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
 
236
 
237
- return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message #Return error
238
  # --- Gradio Interface ---
239
  with gr.Blocks(theme=gr.themes.Base()) as demo:
240
  gr.Markdown("# Market Analysis and Prediction")
@@ -250,6 +338,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
250
  yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
251
  seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
252
  changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
 
253
 
254
  with gr.Column():
255
  forecast_plot = gr.Plot(label="Price Forecast")
@@ -260,6 +349,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
260
  macd_plot = gr.Plot(label="MACD")
261
  forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
262
  growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
 
263
 
264
  # Event Listener to update symbol dropdown based on market type
265
  def update_symbol_choices(market_type):
@@ -283,8 +373,9 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
283
  yearly_box,
284
  seasonality_mode_dd,
285
  changepoint_scale_slider,
 
286
  ],
287
- outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output]
288
  )
289
 
290
  if __name__ == "__main__":
 
1
+ # app.py
2
  import gradio as gr
3
  import pandas as pd
4
+ import numpy as np
5
  from prophet import Prophet
6
  import plotly.graph_objs as go
7
  import math
8
+ import requests # For API calls
9
+ from sklearn.ensemble import RandomForestClassifier # For the model
10
+ from textblob import TextBlob # For sentiment analysis (optional)
11
+
12
 
13
  # --- Replace with your Alpha Vantage API key ---
14
  ALPHA_VANTAGE_API_KEY = "YOUR_ALPHA_VANTAGE_API_KEY" # <--- Replace with your key
 
18
  STOCK_SYMBOLS = ["AAPL", "MSFT"]
19
  INTERVAL_OPTIONS = ["1h", "60min"] # Consistent naming
20
 
21
+ # --- Data Fetching Functions ---
22
+ def fetch_crypto_data(symbol, interval="1h", limit=100):
23
+ """Fetch crypto market data from Binance."""
24
+ try:
25
+ url = f"https://api.binance.com/api/v3/klines"
26
+ params = {"symbol": symbol, "interval": interval, "limit": limit}
27
+ response = requests.get(url, params=params)
28
+ response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
29
+ data = response.json()
30
+ df = pd.DataFrame(data, columns=["timestamp", "open", "high", "low", "close", "volume"])
31
+ df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
32
+ df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
33
+ return df.dropna()
34
+ except requests.exceptions.RequestException as e:
35
+ raise Exception(f"Error fetching crypto data: {e}")
36
+ except (ValueError, KeyError) as e:
37
+ raise Exception(f"Error parsing crypto data: {e}")
38
+
39
+
40
+ def fetch_stock_data(symbol, interval="60min"):
41
+ """Fetch stock market data from Alpha Vantage."""
42
+ try:
43
+ url = f"https://www.alphavantage.co/query"
44
+ params = {"function": "TIME_SERIES_INTRADAY", "symbol": symbol, "interval": interval,
45
+ "apikey": ALPHA_VANTAGE_API_KEY}
46
+ response = requests.get(url, params=params)
47
+ response.raise_for_status() # Raise HTTPError for bad responses
48
+ data = response.json()
49
+ if "Time Series (60min)" in data: # Check the JSON for the correct key
50
+ time_series_data = data["Time Series (60min)"]
51
+ df = pd.DataFrame(time_series_data).T.reset_index()
52
+ df.columns = ["timestamp", "open", "high", "low", "close", "volume"] # Standardize
53
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
54
+ df[["open", "high", "low", "close", "volume"]] = df[["open", "high", "low", "close", "volume"]].astype(float)
55
+ return df.dropna()
56
+ else:
57
+ raise Exception(f"Error: Could not retrieve stock data. Check API Key, Symbol, and Interval. Response: {data}")
58
+
59
+ except requests.exceptions.RequestException as e:
60
+ raise Exception(f"Error fetching stock data: {e}")
61
+ except (ValueError, KeyError) as e:
62
+ raise Exception(f"Error parsing stock data: {e}")
63
+
64
+ def fetch_sentiment_data(keyword): # Placeholder - replace with a real sentiment analysis method
65
+ """Analyze sentiment from social media (placeholder)."""
66
+ try:
67
+ tweets = [
68
+ f"{keyword} is going to moon!",
69
+ f"I hate {keyword}, it's trash!",
70
+ f"{keyword} is amazing!"
71
+ ]
72
+ sentiments = [TextBlob(tweet).sentiment.polarity for tweet in tweets]
73
+ return sum(sentiments) / len(sentiments) / len(sentiments) if sentiments else 0 # Avoid ZeroDivisionError
74
+ except Exception as e:
75
+ print(f"Sentiment analysis error: {e}")
76
+ return 0
77
+
78
  # --- Technical Analysis Functions ---
79
  def calculate_technical_indicators(df):
80
  """Calculates RSI, MACD, and Bollinger Bands."""
 
223
  )
224
  return fig
225
 
226
+ # --- Model Training and Prediction ---
227
+ model = RandomForestClassifier() # Moved here
228
+
229
+ def train_model(df):
230
+ """Train the AI model."""
231
+ if df.empty:
232
+ return # Or raise an exception, or return a default model.
233
+ df["target"] = (df["close"].pct_change() > 0.05).astype(int) # Target: 1 if price increased by >5%
234
+ features = df[["close", "volume"]].dropna()
235
+ target = df["target"].dropna()
236
+ if not features.empty and not target.empty: #check data is available
237
+ model.fit(features, target)
238
+ else:
239
+ print("Not enough data for model training.")
240
+
241
+ def predict_growth(latest_data):
242
+ """Predict explosive growth."""
243
+ if not hasattr(model, 'estimators_') or len(model.estimators_) == 0: # Check if model is trained
244
+ return [0] # Or return an error message, or a default value
245
+
246
+ try:
247
+ prediction = model.predict(latest_data.reshape(1, -1))
248
+ return prediction
249
+ except Exception as e:
250
+ print(f"Prediction error: {e}")
251
+ return [0]
252
+
253
  # --- Main Prediction and Display Function ---
254
+ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale, sentiment_keyword=""):
255
  """Main function to orchestrate data fetching, analysis, and prediction."""
256
  df = pd.DataFrame()
257
  error_message = ""
258
+ sentiment_score = 0 # Initialize sentiment score
259
  # 1. Data Fetching
260
+ try:
261
+ if market_type == "Crypto":
262
+ df = fetch_crypto_data(symbol, interval=interval)
263
+ elif market_type == "Stock":
264
+ df = fetch_stock_data(symbol, interval=interval)
265
+ else:
266
+ error_message = "Invalid market type selected."
267
+ return None, None, None, None, None, "", error_message, 0 # Also return sentiment
 
 
 
 
268
 
269
+ if sentiment_keyword: # If a keyword for sentiment is entered:
270
+ sentiment_score = fetch_sentiment_data(sentiment_keyword)
271
+ except Exception as e:
272
+ error_message = f"Data Fetching Error: {e}"
273
+ return None, None, None, None, None, "", error_message, 0 #Return error + sentiment
274
 
275
+ if df.empty:
276
+ error_message = "No data fetched."
277
+ return None, None, None, None, None, "", error_message, 0 # Return empty + sentiment
278
  # 2. Preprocessing & Technical Analysis
279
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
280
  numeric_cols = ["open", "high", "low", "close", "volume"]
281
  df[numeric_cols] = df[numeric_cols].astype(float)
282
  df = calculate_technical_indicators(df)
 
297
 
298
  if prophet_error:
299
  error_message = f"Prophet Error: {prophet_error}"
300
+ return None, None, None, None, None, "", error_message, sentiment_score #Return prophet error
301
 
302
  forecast_plot = create_forecast_plot(forecast_df)
303
 
 
321
  # Prepare forecast data for the Dataframe output
322
  forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
323
  forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
324
+ return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message, sentiment_score
325
 
 
326
  # --- Gradio Interface ---
327
  with gr.Blocks(theme=gr.themes.Base()) as demo:
328
  gr.Markdown("# Market Analysis and Prediction")
 
338
  yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
339
  seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value="additive")
340
  changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=0.05)
341
+ sentiment_keyword_txt = gr.Textbox(label="Sentiment Keyword (optional)") #Add Sentiment input
342
 
343
  with gr.Column():
344
  forecast_plot = gr.Plot(label="Price Forecast")
 
349
  macd_plot = gr.Plot(label="MACD")
350
  forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
351
  growth_label_output = gr.Label(label="Explosive Growth Prediction") # Added for prediction.
352
+ sentiment_label_output = gr.Number(label="Sentiment Score") # Added for sentiment output
353
 
354
  # Event Listener to update symbol dropdown based on market type
355
  def update_symbol_choices(market_type):
 
373
  yearly_box,
374
  seasonality_mode_dd,
375
  changepoint_scale_slider,
376
+ sentiment_keyword_txt, # Add sentiment keyword to the input
377
  ],
378
+ outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output, gr.Label(label="Error Message"), sentiment_label_output] # Add sentiment score to the output
379
  )
380
 
381
  if __name__ == "__main__":