saneowl commited on
Commit
0351882
·
verified ·
1 Parent(s): 39dad61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -22
app.py CHANGED
@@ -11,20 +11,22 @@ import plotly.graph_objects as go
11
  def fetch_ethereum_data():
12
  """
13
  Fetch historical Ethereum price data using yfinance.
14
- Returns DataFrame with date and price information.
 
15
  """
16
  eth_ticker = yf.Ticker("ETH-USD")
17
- # Get data for the past week
18
  hist_data = eth_ticker.history(period="7d", interval="1h")
19
- return hist_data[['Close']].reset_index()
 
20
 
21
  def prepare_data(data, sequence_length=24):
22
  """
23
  Prepare data for LSTM model by creating sequences and scaling.
24
 
25
  Args:
26
- data: DataFrame with price data
27
- sequence_length: Number of time steps to use for prediction
28
  """
29
  # Scale the data
30
  scaler = MinMaxScaler()
@@ -47,6 +49,7 @@ def prepare_data(data, sequence_length=24):
47
  def create_model(sequence_length):
48
  """
49
  Create and compile LSTM model for time series prediction.
 
50
  """
51
  model = Sequential([
52
  LSTM(50, return_sequences=True, input_shape=(sequence_length, 1)),
@@ -66,14 +69,17 @@ def predict_future_prices(model, last_sequence, scaler, days=7):
66
  model: Trained LSTM model
67
  last_sequence: Last sequence of known prices
68
  scaler: Fitted MinMaxScaler
69
- days: Number of days to predict
70
  """
71
  future_predictions = []
72
  current_sequence = last_sequence.copy()
73
 
74
- for _ in range(days):
 
 
 
75
  # Predict next price
76
- scaled_prediction = model.predict(current_sequence.reshape(1, -1, 1))
77
  # Inverse transform to get actual price
78
  prediction = scaler.inverse_transform(scaled_prediction)[0][0]
79
  future_predictions.append(prediction)
@@ -86,21 +92,19 @@ def predict_future_prices(model, last_sequence, scaler, days=7):
86
 
87
  def create_prediction_plot(historical_data, future_predictions, future_dates):
88
  """
89
- Create an interactive plot showing the last week of historical prices and week-ahead predictions.
 
90
 
91
  Args:
92
- historical_data: DataFrame with historical price data
93
  future_predictions: List of predicted prices
94
- future_dates: List of future dates for predictions
95
- """
96
- """
97
- Create an interactive plot showing historical prices and predictions.
98
  """
99
  fig = go.Figure()
100
 
101
- # Plot historical data
102
  fig.add_trace(go.Scatter(
103
- x=historical_data['Date'],
104
  y=historical_data['Close'],
105
  name='Historical Prices',
106
  line=dict(color='blue')
@@ -115,7 +119,7 @@ def create_prediction_plot(historical_data, future_predictions, future_dates):
115
  ))
116
 
117
  fig.update_layout(
118
- title='Ethereum Price Prediction',
119
  xaxis_title='Date',
120
  yaxis_title='Price (USD)',
121
  hovermode='x unified'
@@ -126,10 +130,11 @@ def create_prediction_plot(historical_data, future_predictions, future_dates):
126
  def predict_ethereum():
127
  """
128
  Main function for Gradio interface that orchestrates the prediction process.
 
129
  """
130
  # Fetch and prepare data
131
  data = fetch_ethereum_data()
132
- sequence_length = 24
133
  X, y, scaler = prepare_data(data, sequence_length)
134
 
135
  # Create and train model
@@ -142,9 +147,9 @@ def predict_ethereum():
142
  # Generate future predictions
143
  future_predictions = predict_future_prices(model, last_sequence, scaler)
144
 
145
- # Create future dates
146
- last_date = data['Date'].iloc[-1]
147
- future_dates = [last_date + timedelta(days=i+1) for i in range(7)]
148
 
149
  # Create and return plot
150
  fig = create_prediction_plot(data, future_predictions, future_dates)
@@ -156,7 +161,7 @@ iface = gr.Interface(
156
  inputs=None,
157
  outputs=gr.Plot(),
158
  title="Ethereum Price Prediction",
159
- description="Click to generate a 7-day price prediction for Ethereum based on historical data.",
160
  theme=gr.themes.Base()
161
  )
162
 
 
11
  def fetch_ethereum_data():
12
  """
13
  Fetch historical Ethereum price data using yfinance.
14
+ Returns DataFrame with datetime index and price information.
15
+ The data is sampled hourly for the past week.
16
  """
17
  eth_ticker = yf.Ticker("ETH-USD")
18
+ # Get hourly data for the past week
19
  hist_data = eth_ticker.history(period="7d", interval="1h")
20
+ # Keep the datetime index and Close price
21
+ return hist_data[['Close']]
22
 
23
  def prepare_data(data, sequence_length=24):
24
  """
25
  Prepare data for LSTM model by creating sequences and scaling.
26
 
27
  Args:
28
+ data: DataFrame with price data and datetime index
29
+ sequence_length: Number of time steps to use for prediction (default: 24 hours)
30
  """
31
  # Scale the data
32
  scaler = MinMaxScaler()
 
49
  def create_model(sequence_length):
50
  """
51
  Create and compile LSTM model for time series prediction.
52
+ Uses a two-layer LSTM architecture followed by dense layers.
53
  """
54
  model = Sequential([
55
  LSTM(50, return_sequences=True, input_shape=(sequence_length, 1)),
 
69
  model: Trained LSTM model
70
  last_sequence: Last sequence of known prices
71
  scaler: Fitted MinMaxScaler
72
+ days: Number of days to predict (default: 7)
73
  """
74
  future_predictions = []
75
  current_sequence = last_sequence.copy()
76
 
77
+ # Convert days to hours since we're using hourly data
78
+ hours = days * 24
79
+
80
+ for _ in range(hours):
81
  # Predict next price
82
+ scaled_prediction = model.predict(current_sequence.reshape(1, -1, 1), verbose=0)
83
  # Inverse transform to get actual price
84
  prediction = scaler.inverse_transform(scaled_prediction)[0][0]
85
  future_predictions.append(prediction)
 
92
 
93
  def create_prediction_plot(historical_data, future_predictions, future_dates):
94
  """
95
+ Create an interactive plot showing the last week of historical prices
96
+ and week-ahead predictions with hourly granularity.
97
 
98
  Args:
99
+ historical_data: DataFrame with historical price data and datetime index
100
  future_predictions: List of predicted prices
101
+ future_dates: List of future datetime indices for predictions
 
 
 
102
  """
103
  fig = go.Figure()
104
 
105
+ # Plot historical data using the datetime index
106
  fig.add_trace(go.Scatter(
107
+ x=historical_data.index,
108
  y=historical_data['Close'],
109
  name='Historical Prices',
110
  line=dict(color='blue')
 
119
  ))
120
 
121
  fig.update_layout(
122
+ title='Ethereum Price Prediction (Hourly)',
123
  xaxis_title='Date',
124
  yaxis_title='Price (USD)',
125
  hovermode='x unified'
 
130
  def predict_ethereum():
131
  """
132
  Main function for Gradio interface that orchestrates the prediction process.
133
+ Handles hourly data and generates predictions for the next week.
134
  """
135
  # Fetch and prepare data
136
  data = fetch_ethereum_data()
137
+ sequence_length = 24 # Use 24 hours of data for prediction
138
  X, y, scaler = prepare_data(data, sequence_length)
139
 
140
  # Create and train model
 
147
  # Generate future predictions
148
  future_predictions = predict_future_prices(model, last_sequence, scaler)
149
 
150
+ # Create future dates (hourly intervals)
151
+ last_date = data.index[-1]
152
+ future_dates = [last_date + timedelta(hours=i+1) for i in range(len(future_predictions))]
153
 
154
  # Create and return plot
155
  fig = create_prediction_plot(data, future_predictions, future_dates)
 
161
  inputs=None,
162
  outputs=gr.Plot(),
163
  title="Ethereum Price Prediction",
164
+ description="Click to generate a 7-day price prediction for Ethereum based on hourly historical data.",
165
  theme=gr.themes.Base()
166
  )
167