MacDash commited on
Commit
42e86be
·
verified ·
1 Parent(s): be359fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +356 -60
app.py CHANGED
@@ -1,64 +1,360 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from prophet import Prophet
5
+ import plotly.graph_objs as go
6
+ import requests
7
+ from sklearn.ensemble import RandomForestClassifier
8
+ from textblob import TextBlob
9
+ import yfinance as yf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # --- Constants ---
12
+ CRYPTO_SYMBOLS = ["BTC-USD", "ETH-USD", "LTC-USD", "XRP-USD"]
13
+ STOCK_SYMBOLS = ["AAPL", "MSFT", "GOOGL", "AMZN"]
14
+ INTERVAL_OPTIONS = ["1h", "1d", "1wk"]
15
+ DEFAULT_FORECAST_STEPS = 24
16
+ DEFAULT_DAILY_SEASONALITY = True
17
+ DEFAULT_WEEKLY_SEASONALITY = True
18
+ DEFAULT_YEARLY_SEASONALITY = False
19
+ DEFAULT_SEASONALITY_MODE = "additive"
20
+ DEFAULT_CHANGEPOINT_PRIOR_SCALE = 0.05
21
+ RANDOM_FOREST_PARAMS = {
22
+ "n_estimators": 100,
23
+ "max_depth": 10,
24
+ "random_state": 42
25
+ }
26
+
27
+ # --- Data Fetching Functions ---
28
+ def fetch_crypto_data(symbol, interval="1h", limit=100):
29
+ try:
30
+ ticker = yf.Ticker(symbol)
31
+ if interval == "1h":
32
+ period = "1d"
33
+ df = ticker.history(period=period, interval="1h")
34
+ elif interval == "1d":
35
+ df = ticker.history(period="1y", interval=interval)
36
+ elif interval == "1wk":
37
+ df = ticker.history(period="5y", interval=interval)
38
+ else:
39
+ raise ValueError("Invalid interval for yfinance.")
40
+ if df.empty:
41
+ raise Exception("No data returned from yfinance.")
42
+ df.reset_index(inplace=True)
43
+ df.rename(columns={"Datetime": "timestamp", "Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}, inplace=True)
44
+ df = df[["timestamp", "open", "high", "low", "close", "volume"]]
45
+ return df.dropna()
46
+ except Exception as e:
47
+ raise Exception(f"Error fetching crypto data from yfinance: {e}")
48
+
49
+ def fetch_stock_data(symbol, interval="1d"):
50
+ try:
51
+ ticker = yf.Ticker(symbol)
52
+ df = ticker.history(period="1y", interval=interval)
53
+ if df.empty:
54
+ raise Exception("No data returned from yfinance.")
55
+ df.reset_index(inplace=True)
56
+ df.rename(columns={"Date": "timestamp", "Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}, inplace=True)
57
+ df = df[["timestamp", "open", "high", "low", "close", "volume"]]
58
+ return df.dropna()
59
+ except Exception as e:
60
+ raise Exception(f"Error fetching stock data from yfinance: {e}")
61
+
62
+ def fetch_sentiment_data(keyword):
63
+ try:
64
+ tweets = [
65
+ f"{keyword} is going to moon!",
66
+ f"I hate {keyword}, it's trash!",
67
+ f"{keyword} is amazing!"
68
+ ]
69
+ sentiments = [TextBlob(tweet).sentiment.polarity for tweet in tweets]
70
+ return sum(sentiments) / len(sentiments) if sentiments else 0
71
+ except Exception as e:
72
+ print(f"Sentiment analysis error: {e}")
73
+ return 0
74
+
75
+ # --- Technical Analysis Functions ---
76
+ def calculate_technical_indicators(df):
77
+ if df.empty:
78
+ return df
79
+
80
+ delta = df['close'].diff()
81
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
82
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
83
+ rs = gain / loss
84
+ df['RSI'] = 100 - (100 / (1 + rs))
85
+
86
+ exp1 = df['close'].ewm(span=12, adjust=False).mean()
87
+ exp2 = df['close'].ewm(span=26, adjust=False).mean()
88
+ df['MACD'] = exp1 - exp2
89
+ df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()
90
+
91
+ df['MA20'] = df['close'].rolling(window=20).mean()
92
+ df['BB_upper'] = df['MA20'] + 2 * df['close'].rolling(window=20).std()
93
+ df['BB_lower'] = df['MA20'] - 2 * df['close'].rolling(window=20).std()
94
+
95
+ return df
96
+
97
+ def create_technical_charts(df):
98
+ if df.empty:
99
+ return None, None, None
100
+
101
+ fig1 = go.Figure()
102
+ fig1.add_trace(go.Candlestick(
103
+ x=df['timestamp'],
104
+ open=df['open'],
105
+ high=df['high'],
106
+ low=df['low'],
107
+ close=df['close'],
108
+ name='Price'
109
+ ))
110
+ fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_upper'], name='Upper BB', line=dict(color='gray', dash='dash')))
111
+ fig1.add_trace(go.Scatter(x=df['timestamp'], y=df['BB_lower'], name='Lower BB', line=dict(color='gray', dash='dash')))
112
+ fig1.update_layout(title='Price and Bollinger Bands', xaxis_title='Date', yaxis_title='Price')
113
+
114
+ fig2 = go.Figure()
115
+ fig2.add_trace(go.Scatter(x=df['timestamp'], y=df['RSI'], name='RSI'))
116
+ fig2.add_hline(y=70, line_dash="dash", line_color="red")
117
+ fig2.add_hline(y=30, line_dash="dash", line_color="green")
118
+ fig2.update_layout(title='RSI Indicator', xaxis_title='Date', yaxis_title='RSI')
119
+
120
+ fig3 = go.Figure()
121
+ fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['MACD'], name='MACD'))
122
+ fig3.add_trace(go.Scatter(x=df['timestamp'], y=df['Signal_Line'], name='Signal Line'))
123
+ fig3.update_layout(title='MACD', xaxis_title='Date', yaxis_title='Value')
124
+
125
+ return fig1, fig2, fig3
126
+
127
+ # --- Prophet Forecasting Functions ---
128
+ def prepare_data_for_prophet(df):
129
+ if df.empty:
130
+ return pd.DataFrame(columns=["ds", "y"])
131
+ df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
132
+ return df_prophet[["ds", "y"]]
133
+
134
+ def prophet_forecast(df_prophet, periods=10, freq="h", daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=False, seasonality_mode="additive", changepoint_prior_scale=0.05):
135
+ if df_prophet.empty:
136
+ return pd.DataFrame(), "No data for Prophet."
137
+
138
+ try:
139
+ model = Prophet(
140
+ daily_seasonality=daily_seasonality,
141
+ weekly_seasonality=weekly_seasonality,
142
+ yearly_seasonality=yearly_seasonality,
143
+ seasonality_mode=seasonality_mode,
144
+ changepoint_prior_scale=changepoint_prior_scale,
145
+ )
146
+ model.fit(df_prophet)
147
+ future = model.make_future_dataframe(periods=periods, freq=freq)
148
+ forecast = model.predict(future)
149
+ return forecast, ""
150
+ except Exception as e:
151
+ return pd.DataFrame(), f"Forecast error: {e}"
152
+
153
+ def prophet_wrapper(df_prophet, forecast_steps, freq, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale):
154
+ if len(df_prophet) < 10:
155
+ return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
156
+
157
+ full_forecast, err = prophet_forecast(
158
+ df_prophet,
159
+ forecast_steps,
160
+ freq,
161
+ daily_seasonality,
162
+ weekly_seasonality,
163
+ yearly_seasonality,
164
+ seasonality_mode,
165
+ changepoint_prior_scale,
166
+ )
167
+ if err:
168
+ return pd.DataFrame(), err
169
+
170
+ future_only = full_forecast.loc[len(df_prophet):, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
171
+ return future_only, ""
172
+
173
+ def create_forecast_plot(forecast_df):
174
+ if forecast_df.empty:
175
+ return go.Figure()
176
+
177
+ fig = go.Figure()
178
+ fig.add_trace(go.Scatter(
179
+ x=forecast_df["ds"],
180
+ y=forecast_df["yhat"],
181
+ mode="lines",
182
+ name="Forecast",
183
+ line=dict(color="blue", width=2)
184
+ ))
185
+
186
+ fig.add_trace(go.Scatter(
187
+ x=forecast_df["ds"],
188
+ y=forecast_df["yhat_lower"],
189
+ fill=None,
190
+ mode="lines",
191
+ line=dict(width=0),
192
+ showlegend=True,
193
+ name="Lower Bound"
194
+ ))
195
+
196
+ fig.add_trace(go.Scatter(
197
+ x=forecast_df["ds"],
198
+ y=forecast_df["yhat_upper"],
199
+ fill="tonexty",
200
+ mode="lines",
201
+ line=dict(width=0),
202
+ name="Upper Bound"
203
+ ))
204
+
205
+ fig.update_layout(
206
+ title="Price Forecast",
207
+ xaxis_title="Time",
208
+ yaxis_title="Price",
209
+ hovermode="x unified",
210
+ template="plotly_white",
211
+ )
212
+ return fig
213
+
214
+ # --- Model Training and Prediction ---
215
+ model = RandomForestClassifier(**RANDOM_FOREST_PARAMS)
216
+
217
+ def train_model(df):
218
+ if df.empty:
219
+ return
220
+ df["target"] = (df["close"].pct_change() > 0.05).astype(int)
221
+ features = df[["close", "volume"]].dropna()
222
+ target = df["target"].dropna()
223
+ if not features.empty and not target.empty:
224
+ model.fit(features, target)
225
+ else:
226
+ print("Not enough data for model training.")
227
+
228
+ def predict_growth(latest_data):
229
+ if not hasattr(model, 'estimators_') or len(model.estimators_) == 0:
230
+ return [0]
231
+
232
+ try:
233
+ prediction = model.predict(latest_data.reshape(1, -1))
234
+ return prediction
235
+ except Exception as e:
236
+ print(f"Prediction error: {e}")
237
+ return [0]
238
+
239
+ # --- Main Prediction and Display Function ---
240
+ def analyze_market(market_type, symbol, interval, forecast_steps, daily_seasonality, weekly_seasonality, yearly_seasonality, seasonality_mode, changepoint_prior_scale, sentiment_keyword=""):
241
+ df = pd.DataFrame()
242
+ error_message = ""
243
+ sentiment_score = 0
244
+
245
+ try:
246
+ if market_type == "Crypto":
247
+ df = fetch_crypto_data(symbol, interval=interval)
248
+ elif market_type == "Stock":
249
+ df = fetch_stock_data(symbol, interval=interval)
250
+ else:
251
+ error_message = "Invalid market type selected."
252
+ return None, None, None, None, None, "", error_message, 0
253
+
254
+ if sentiment_keyword:
255
+ sentiment_score = fetch_sentiment_data(sentiment_keyword)
256
+ except Exception as e:
257
+ error_message = f"Data Fetching Error: {e}"
258
+ return None, None, None, None, None, "", error_message, 0
259
+
260
+ if df.empty:
261
+ error_message = "No data fetched."
262
+ return None, None, None, None, None, "", error_message, 0
263
+
264
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
265
+ numeric_cols = ["open", "high", "low", "close", "volume"]
266
+ df[numeric_cols] = df[numeric_cols].astype(float)
267
+ df = calculate_technical_indicators(df)
268
+
269
+ df_prophet = prepare_data_for_prophet(df)
270
+ freq = "h" if interval == "1h" or interval == "60min" else "d"
271
+ forecast_df, prophet_error = prophet_wrapper(
272
+ df_prophet,
273
+ forecast_steps,
274
+ freq,
275
+ daily_seasonality,
276
+ weekly_seasonality,
277
+ yearly_seasonality,
278
+ seasonality_mode,
279
+ changepoint_prior_scale,
280
+ )
281
+
282
+ if prophet_error:
283
+ error_message = f"Prophet Error: {prophet_error}"
284
+ return None, None, None, None, None, "", error_message, sentiment_score
285
+
286
+ forecast_plot = create_forecast_plot(forecast_df)
287
+ tech_plot, rsi_plot, macd_plot = create_technical_charts(df)
288
+
289
+ try:
290
+ train_model(df.copy())
291
+ if not df.empty:
292
+ latest_data = df[["close", "volume"]].iloc[-1].values
293
+ growth_prediction = predict_growth(latest_data)
294
+ growth_label = "Yes" if growth_prediction[0] == 1 else "No"
295
+ else:
296
+ growth_label = "N/A: Insufficient Data"
297
+ except Exception as e:
298
+ error_message = f"Model Error: {e}"
299
+ growth_label = "N/A"
300
+
301
+ forecast_df_display = forecast_df.loc[:, ["ds", "yhat", "yhat_lower", "yhat_upper"]].copy()
302
+ forecast_df_display.rename(columns={"ds": "Date", "yhat": "Forecast", "yhat_lower": "Lower Bound", "yhat_upper": "Upper Bound"}, inplace=True)
303
+ return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df_display, growth_label, error_message, sentiment_score
304
+
305
+ # --- Gradio Interface ---
306
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
307
+ gr.Markdown("# Market Analysis and Prediction")
308
+
309
+ with gr.Row():
310
+ with gr.Column():
311
+ market_type_dd = gr.Radio(label="Market Type", choices=["Crypto", "Stock"], value="Crypto")
312
+ symbol_dd = gr.Dropdown(label="Symbol", choices=CRYPTO_SYMBOLS, value="BTC-USD")
313
+ interval_dd = gr.Dropdown(label="Interval", choices=INTERVAL_OPTIONS, value="1h")
314
+ forecast_steps_slider = gr.Slider(label="Forecast Steps", minimum=1, maximum=100, value=DEFAULT_FORECAST_STEPS, step=1)
315
+ daily_box = gr.Checkbox(label="Daily Seasonality", value=DEFAULT_DAILY_SEASONALITY)
316
+ weekly_box = gr.Checkbox(label="Weekly Seasonality", value=DEFAULT_WEEKLY_SEASONALITY)
317
+ yearly_box = gr.Checkbox(label="Yearly Seasonality", value=DEFAULT_YEARLY_SEASONALITY)
318
+ seasonality_mode_dd = gr.Dropdown(label="Seasonality Mode", choices=["additive", "multiplicative"], value=DEFAULT_SEASONALITY_MODE)
319
+ changepoint_scale_slider = gr.Slider(label="Changepoint Prior Scale", minimum=0.01, maximum=1.0, step=0.01, value=DEFAULT_CHANGEPOINT_PRIOR_SCALE)
320
+ sentiment_keyword_txt = gr.Textbox(label="Sentiment Keyword (optional)")
321
+
322
+ with gr.Column():
323
+ forecast_plot = gr.Plot(label="Price Forecast")
324
+ with gr.Row():
325
+ tech_plot = gr.Plot(label="Technical Analysis")
326
+ rsi_plot = gr.Plot(label="RSI Indicator")
327
+ with gr.Row():
328
+ macd_plot = gr.Plot(label="MACD")
329
+ forecast_df = gr.Dataframe(label="Forecast Data", headers=["Date", "Forecast", "Lower Bound", "Upper Bound"])
330
+ growth_label_output = gr.Label(label="Explosive Growth Prediction")
331
+ sentiment_label_output = gr.Number(label="Sentiment Score")
332
+
333
+ def update_symbol_choices(market_type):
334
+ if market_type == "Crypto":
335
+ return gr.Dropdown(choices=CRYPTO_SYMBOLS, value="BTC-USD")
336
+ elif market_type == "Stock":
337
+ return gr.Dropdown(choices=STOCK_SYMBOLS, value="AAPL")
338
+ return gr.Dropdown(choices=[], value=None)
339
+ market_type_dd.change(fn=update_symbol_choices, inputs=[market_type_dd], outputs=[symbol_dd])
340
+
341
+ analyze_button = gr.Button("Analyze Market", variant="primary")
342
+ analyze_button.click(
343
+ fn=analyze_market,
344
+ inputs=[
345
+ market_type_dd,
346
+ symbol_dd,
347
+ interval_dd,
348
+ forecast_steps_slider,
349
+ daily_box,
350
+ weekly_box,
351
+ yearly_box,
352
+ seasonality_mode_dd,
353
+ changepoint_scale_slider,
354
+ sentiment_keyword_txt,
355
+ ],
356
+ outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df, growth_label_output, gr.Label(label="Error Message"), sentiment_label_output]
357
+ )
358
 
359
  if __name__ == "__main__":
360
+ demo.launch()