saneowl commited on
Commit
4e5436b
·
verified ·
1 Parent(s): dbc69f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -63
app.py CHANGED
@@ -1,79 +1,156 @@
1
  import gradio as gr
 
 
 
2
  import yfinance as yf
 
 
 
3
  import plotly.graph_objects as go
4
- from statsmodels.tsa.arima.model import ARIMA
5
- import pandas as pd
6
- import logging
7
 
8
- logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
 
 
9
 
10
- def fetch_eth_price(period):
11
- eth = yf.Ticker("ETH-USD")
12
- if period == '1d':
13
- data = eth.history(period="1d", interval="1m")
14
- predict_steps = 60 # Next 60 minutes
15
- freq = 'min' # Minute frequency
16
- elif period == '5d':
17
- data = eth.history(period="5d", interval="15m")
18
- predict_steps = 96 # Next 24 hours
19
- freq = '15min' # 15 minutes frequency
20
- elif period == '1wk':
21
- data = eth.history(period="1wk", interval="30m")
22
- predict_steps = 336 # Next 7 days
23
- freq = '30min' # 30 minutes frequency
24
- elif period == '1mo':
25
- data = eth.history(period="1mo", interval="1h")
26
- predict_steps = 720 # Next 30 days
27
- freq = 'H' # Hourly frequency
28
- else:
29
- return None, None, None
30
-
31
- data.index = pd.DatetimeIndex(data.index)
32
- data = data.asfreq(freq) # Ensure the data has a consistent frequency
33
-
34
- # Limit the data to the last 200 points to reduce prediction time
35
- data = data[-200:]
36
-
37
- return data, predict_steps, freq
38
 
39
- def make_predictions(data, predict_steps, freq):
40
- if data is None or data.empty:
41
- logging.error("No data available for prediction.")
42
- return pd.DataFrame(index=pd.date_range(start=pd.Timestamp.now(), periods=predict_steps+1, freq=freq)[1:])
 
 
 
 
 
 
 
 
 
43
 
44
- logging.info(f"Starting model training with {len(data)} data points...")
45
- model = ARIMA(data['Close'], order=(5, 1, 0))
46
- model_fit = model.fit()
47
- logging.info("Model training completed.")
48
 
49
- forecast = model_fit.forecast(steps=predict_steps)
50
- future_dates = pd.date_range(start=data.index[-1], periods=predict_steps+1, freq=freq, inclusive='right')
51
- forecast_df = pd.DataFrame(forecast, index=future_dates[1:], columns=['Prediction'])
 
 
 
 
 
52
 
53
- logging.info("Predictions generated successfully.")
54
- return forecast_df
55
-
56
- def plot_eth(period):
57
- data, predict_steps, freq = fetch_eth_price(period)
58
- forecast_df = make_predictions(data, predict_steps, freq)
 
 
 
 
59
 
 
 
 
 
 
 
60
  fig = go.Figure()
61
- fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='ETH Price'))
62
- fig.add_trace(go.Scatter(x=forecast_df.index, y=forecast_df['Prediction'], mode='lines', name='Prediction', line=dict(dash='dash', color='orange')))
63
- fig.update_layout(title=f"ETH Price and Predictions ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
64
 
65
- logging.info("Plotting completed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return fig
67
 
68
- def refresh_predictions(period):
69
- return plot_eth(period)
70
-
71
- with gr.Blocks() as iface:
72
- period = gr.Radio(["1d", "5d", "1wk", "1mo"], label="Select Period")
73
- plot = gr.Plot()
74
- refresh_button = gr.Button("Refresh Predictions and Prices")
 
 
 
 
 
75
 
76
- period.change(fn=plot_eth, inputs=period, outputs=plot)
77
- refresh_button.click(fn=refresh_predictions, inputs=period, outputs=plot)
78
 
79
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from datetime import datetime, timedelta
5
  import yfinance as yf
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from tensorflow.keras.models import Sequential
8
+ from tensorflow.keras.layers import LSTM, Dense
9
  import plotly.graph_objects as go
 
 
 
10
 
11
+ def fetch_ethereum_data():
12
+ """
13
+ Fetch historical Ethereum price data using yfinance.
14
+ Returns DataFrame with date and price information.
15
+ """
16
+ eth_ticker = yf.Ticker("ETH-USD")
17
+ # Get data for the past year
18
+ hist_data = eth_ticker.history(period="1y")
19
+ return hist_data[['Close']].reset_index()
20
 
21
+ def prepare_data(data, sequence_length=60):
22
+ """
23
+ Prepare data for LSTM model by creating sequences and scaling.
24
+
25
+ Args:
26
+ data: DataFrame with price data
27
+ sequence_length: Number of time steps to use for prediction
28
+ """
29
+ # Scale the data
30
+ scaler = MinMaxScaler()
31
+ scaled_data = scaler.fit_transform(data['Close'].values.reshape(-1, 1))
32
+
33
+ # Create sequences for training
34
+ X, y = [], []
35
+ for i in range(sequence_length, len(scaled_data)):
36
+ X.append(scaled_data[i-sequence_length:i, 0])
37
+ y.append(scaled_data[i, 0])
38
+
39
+ X = np.array(X)
40
+ y = np.array(y)
41
+
42
+ # Reshape X for LSTM input
43
+ X = X.reshape(X.shape[0], X.shape[1], 1)
44
+
45
+ return X, y, scaler
 
 
 
46
 
47
+ def create_model(sequence_length):
48
+ """
49
+ Create and compile LSTM model for time series prediction.
50
+ """
51
+ model = Sequential([
52
+ LSTM(50, return_sequences=True, input_shape=(sequence_length, 1)),
53
+ LSTM(50, return_sequences=False),
54
+ Dense(25),
55
+ Dense(1)
56
+ ])
57
+
58
+ model.compile(optimizer='adam', loss='mse')
59
+ return model
60
 
61
+ def predict_future_prices(model, last_sequence, scaler, days=7):
62
+ """
63
+ Predict future prices using the trained model.
 
64
 
65
+ Args:
66
+ model: Trained LSTM model
67
+ last_sequence: Last sequence of known prices
68
+ scaler: Fitted MinMaxScaler
69
+ days: Number of days to predict
70
+ """
71
+ future_predictions = []
72
+ current_sequence = last_sequence.copy()
73
 
74
+ for _ in range(days):
75
+ # Predict next price
76
+ scaled_prediction = model.predict(current_sequence.reshape(1, -1, 1))
77
+ # Inverse transform to get actual price
78
+ prediction = scaler.inverse_transform(scaled_prediction)[0][0]
79
+ future_predictions.append(prediction)
80
+
81
+ # Update sequence for next prediction
82
+ current_sequence = np.roll(current_sequence, -1)
83
+ current_sequence[-1] = scaled_prediction
84
 
85
+ return future_predictions
86
+
87
+ def create_prediction_plot(historical_data, future_predictions, future_dates):
88
+ """
89
+ Create an interactive plot showing historical prices and predictions.
90
+ """
91
  fig = go.Figure()
 
 
 
92
 
93
+ # Plot historical data
94
+ fig.add_trace(go.Scatter(
95
+ x=historical_data['Date'],
96
+ y=historical_data['Close'],
97
+ name='Historical Prices',
98
+ line=dict(color='blue')
99
+ ))
100
+
101
+ # Plot predictions
102
+ fig.add_trace(go.Scatter(
103
+ x=future_dates,
104
+ y=future_predictions,
105
+ name='Predictions',
106
+ line=dict(color='red', dash='dash')
107
+ ))
108
+
109
+ fig.update_layout(
110
+ title='Ethereum Price Prediction',
111
+ xaxis_title='Date',
112
+ yaxis_title='Price (USD)',
113
+ hovermode='x unified'
114
+ )
115
+
116
  return fig
117
 
118
+ def predict_ethereum():
119
+ """
120
+ Main function for Gradio interface that orchestrates the prediction process.
121
+ """
122
+ # Fetch and prepare data
123
+ data = fetch_ethereum_data()
124
+ sequence_length = 60
125
+ X, y, scaler = prepare_data(data, sequence_length)
126
+
127
+ # Create and train model
128
+ model = create_model(sequence_length)
129
+ model.fit(X, y, epochs=50, batch_size=32, verbose=0)
130
 
131
+ # Prepare last sequence for prediction
132
+ last_sequence = scaler.transform(data['Close'].values[-sequence_length:].reshape(-1, 1))
133
 
134
+ # Generate future predictions
135
+ future_predictions = predict_future_prices(model, last_sequence, scaler)
136
+
137
+ # Create future dates
138
+ last_date = data['Date'].iloc[-1]
139
+ future_dates = [last_date + timedelta(days=i+1) for i in range(7)]
140
+
141
+ # Create and return plot
142
+ fig = create_prediction_plot(data, future_predictions, future_dates)
143
+ return fig
144
+
145
+ # Create Gradio interface
146
+ iface = gr.Interface(
147
+ fn=predict_ethereum,
148
+ inputs=None,
149
+ outputs=gr.Plot(),
150
+ title="Ethereum Price Prediction",
151
+ description="Click to generate a 7-day price prediction for Ethereum based on historical data.",
152
+ theme=gr.themes.Base()
153
+ )
154
+
155
+ if __name__ == "__main__":
156
+ iface.launch()