Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,79 +1,156 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
import yfinance as yf
|
|
|
|
|
|
|
3 |
import plotly.graph_objects as go
|
4 |
-
from statsmodels.tsa.arima.model import ARIMA
|
5 |
-
import pandas as pd
|
6 |
-
import logging
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
def
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
data = data[-200:]
|
36 |
-
|
37 |
-
return data, predict_steps, freq
|
38 |
|
39 |
-
def
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
logging.info("Model training completed.")
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
fig = go.Figure()
|
61 |
-
fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='ETH Price'))
|
62 |
-
fig.add_trace(go.Scatter(x=forecast_df.index, y=forecast_df['Prediction'], mode='lines', name='Prediction', line=dict(dash='dash', color='orange')))
|
63 |
-
fig.update_layout(title=f"ETH Price and Predictions ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
|
64 |
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
return fig
|
67 |
|
68 |
-
def
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
|
77 |
-
|
78 |
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from datetime import datetime, timedelta
|
5 |
import yfinance as yf
|
6 |
+
from sklearn.preprocessing import MinMaxScaler
|
7 |
+
from tensorflow.keras.models import Sequential
|
8 |
+
from tensorflow.keras.layers import LSTM, Dense
|
9 |
import plotly.graph_objects as go
|
|
|
|
|
|
|
10 |
|
11 |
+
def fetch_ethereum_data():
|
12 |
+
"""
|
13 |
+
Fetch historical Ethereum price data using yfinance.
|
14 |
+
Returns DataFrame with date and price information.
|
15 |
+
"""
|
16 |
+
eth_ticker = yf.Ticker("ETH-USD")
|
17 |
+
# Get data for the past year
|
18 |
+
hist_data = eth_ticker.history(period="1y")
|
19 |
+
return hist_data[['Close']].reset_index()
|
20 |
|
21 |
+
def prepare_data(data, sequence_length=60):
|
22 |
+
"""
|
23 |
+
Prepare data for LSTM model by creating sequences and scaling.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
data: DataFrame with price data
|
27 |
+
sequence_length: Number of time steps to use for prediction
|
28 |
+
"""
|
29 |
+
# Scale the data
|
30 |
+
scaler = MinMaxScaler()
|
31 |
+
scaled_data = scaler.fit_transform(data['Close'].values.reshape(-1, 1))
|
32 |
+
|
33 |
+
# Create sequences for training
|
34 |
+
X, y = [], []
|
35 |
+
for i in range(sequence_length, len(scaled_data)):
|
36 |
+
X.append(scaled_data[i-sequence_length:i, 0])
|
37 |
+
y.append(scaled_data[i, 0])
|
38 |
+
|
39 |
+
X = np.array(X)
|
40 |
+
y = np.array(y)
|
41 |
+
|
42 |
+
# Reshape X for LSTM input
|
43 |
+
X = X.reshape(X.shape[0], X.shape[1], 1)
|
44 |
+
|
45 |
+
return X, y, scaler
|
|
|
|
|
|
|
46 |
|
47 |
+
def create_model(sequence_length):
|
48 |
+
"""
|
49 |
+
Create and compile LSTM model for time series prediction.
|
50 |
+
"""
|
51 |
+
model = Sequential([
|
52 |
+
LSTM(50, return_sequences=True, input_shape=(sequence_length, 1)),
|
53 |
+
LSTM(50, return_sequences=False),
|
54 |
+
Dense(25),
|
55 |
+
Dense(1)
|
56 |
+
])
|
57 |
+
|
58 |
+
model.compile(optimizer='adam', loss='mse')
|
59 |
+
return model
|
60 |
|
61 |
+
def predict_future_prices(model, last_sequence, scaler, days=7):
|
62 |
+
"""
|
63 |
+
Predict future prices using the trained model.
|
|
|
64 |
|
65 |
+
Args:
|
66 |
+
model: Trained LSTM model
|
67 |
+
last_sequence: Last sequence of known prices
|
68 |
+
scaler: Fitted MinMaxScaler
|
69 |
+
days: Number of days to predict
|
70 |
+
"""
|
71 |
+
future_predictions = []
|
72 |
+
current_sequence = last_sequence.copy()
|
73 |
|
74 |
+
for _ in range(days):
|
75 |
+
# Predict next price
|
76 |
+
scaled_prediction = model.predict(current_sequence.reshape(1, -1, 1))
|
77 |
+
# Inverse transform to get actual price
|
78 |
+
prediction = scaler.inverse_transform(scaled_prediction)[0][0]
|
79 |
+
future_predictions.append(prediction)
|
80 |
+
|
81 |
+
# Update sequence for next prediction
|
82 |
+
current_sequence = np.roll(current_sequence, -1)
|
83 |
+
current_sequence[-1] = scaled_prediction
|
84 |
|
85 |
+
return future_predictions
|
86 |
+
|
87 |
+
def create_prediction_plot(historical_data, future_predictions, future_dates):
|
88 |
+
"""
|
89 |
+
Create an interactive plot showing historical prices and predictions.
|
90 |
+
"""
|
91 |
fig = go.Figure()
|
|
|
|
|
|
|
92 |
|
93 |
+
# Plot historical data
|
94 |
+
fig.add_trace(go.Scatter(
|
95 |
+
x=historical_data['Date'],
|
96 |
+
y=historical_data['Close'],
|
97 |
+
name='Historical Prices',
|
98 |
+
line=dict(color='blue')
|
99 |
+
))
|
100 |
+
|
101 |
+
# Plot predictions
|
102 |
+
fig.add_trace(go.Scatter(
|
103 |
+
x=future_dates,
|
104 |
+
y=future_predictions,
|
105 |
+
name='Predictions',
|
106 |
+
line=dict(color='red', dash='dash')
|
107 |
+
))
|
108 |
+
|
109 |
+
fig.update_layout(
|
110 |
+
title='Ethereum Price Prediction',
|
111 |
+
xaxis_title='Date',
|
112 |
+
yaxis_title='Price (USD)',
|
113 |
+
hovermode='x unified'
|
114 |
+
)
|
115 |
+
|
116 |
return fig
|
117 |
|
118 |
+
def predict_ethereum():
|
119 |
+
"""
|
120 |
+
Main function for Gradio interface that orchestrates the prediction process.
|
121 |
+
"""
|
122 |
+
# Fetch and prepare data
|
123 |
+
data = fetch_ethereum_data()
|
124 |
+
sequence_length = 60
|
125 |
+
X, y, scaler = prepare_data(data, sequence_length)
|
126 |
+
|
127 |
+
# Create and train model
|
128 |
+
model = create_model(sequence_length)
|
129 |
+
model.fit(X, y, epochs=50, batch_size=32, verbose=0)
|
130 |
|
131 |
+
# Prepare last sequence for prediction
|
132 |
+
last_sequence = scaler.transform(data['Close'].values[-sequence_length:].reshape(-1, 1))
|
133 |
|
134 |
+
# Generate future predictions
|
135 |
+
future_predictions = predict_future_prices(model, last_sequence, scaler)
|
136 |
+
|
137 |
+
# Create future dates
|
138 |
+
last_date = data['Date'].iloc[-1]
|
139 |
+
future_dates = [last_date + timedelta(days=i+1) for i in range(7)]
|
140 |
+
|
141 |
+
# Create and return plot
|
142 |
+
fig = create_prediction_plot(data, future_predictions, future_dates)
|
143 |
+
return fig
|
144 |
+
|
145 |
+
# Create Gradio interface
|
146 |
+
iface = gr.Interface(
|
147 |
+
fn=predict_ethereum,
|
148 |
+
inputs=None,
|
149 |
+
outputs=gr.Plot(),
|
150 |
+
title="Ethereum Price Prediction",
|
151 |
+
description="Click to generate a 7-day price prediction for Ethereum based on historical data.",
|
152 |
+
theme=gr.themes.Base()
|
153 |
+
)
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
iface.launch()
|