# Hugging Face Space page header (scraped): "Spaces: Sleeping"
import datetime
import io
import json
import logging
import time

import ccxt
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from sklearn.preprocessing import StandardScaler
from ta import add_all_ta_features
# Rolling plot history shared across requests. Each entry of
# ``predictions_history`` is a [close, open, high, low] list, and the entry at
# the same index in ``timestamps_history`` is the datetime it is plotted at.
predictions_history, timestamps_history = [], []
def gradio_interface():
    """Build and launch the Gradio app that serves on-demand candle predictions.

    The click handler calls ``fetch_data``, ``calculate_all_indicators``,
    ``predict`` and ``selected_columns``, which are defined elsewhere in this
    file. It also appends to the module-level ``predictions_history`` and
    ``timestamps_history`` lists. ``app.launch`` starts the web server.
    """
    logger = logging.getLogger(__name__)

    def process_data():
        """Fetch fresh candles, predict the next candle and render the history plot.

        Returns:
            tuple: ``(result_text, plot_image)``. On any failure the text is an
            ``"Error: ..."`` message and the image is ``None``.
        """
        try:
            df = calculate_all_indicators(fetch_data())
            # NOTE(review): iloc[-2] presumably skips the still-forming latest
            # candle -- confirm this is intended.
            row = df.iloc[-2]
            features = row[selected_columns].fillna(0).values.tolist()
            prediction = predict(features)
            result = (f"Close: {prediction['Predicted Target']:.2f}, "
                      f"Open: {prediction['Predicted Target Open']:.2f}, "
                      f"High: {prediction['Predicted Target High']:.2f}, "
                      f"Low: {prediction['Predicted Target Low']:.2f}")
            return result, plot_predictions(prediction)
        except Exception as e:
            # UI boundary: show the failure in the textbox, but keep the full
            # traceback in the server log instead of discarding it.
            logger.exception("Prediction failed")
            return f"Error: {str(e)}", None

    def plot_predictions(predictions):
        """Record ``predictions`` in the rolling history and render it as a PIL image.

        A point is appended only when it differs from the last recorded one.
        It is time-stamped 15 minutes ahead of now, the horizon advertised in
        the UI text. Only the 10 most recent points are kept.
        """
        current_time = datetime.datetime.now()
        current_prediction = [
            predictions['Predicted Target'],
            predictions['Predicted Target Open'],
            predictions['Predicted Target High'],
            predictions['Predicted Target Low'],
        ]
        if not predictions_history or current_prediction != predictions_history[-1]:
            timestamps_history.append(current_time + datetime.timedelta(minutes=15))
            predictions_history.append(current_prediction)
            if len(timestamps_history) > 10:
                timestamps_history.pop(0)
                predictions_history.pop(0)

        # pyplot keeps every figure alive until it is explicitly closed. Closing
        # it here stops each request from leaking a figure for the server's lifetime.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            labels = ['Close', 'Open', 'High', 'Low']
            colors = ['green', 'red', 'blue', 'purple']
            for i, (label, color) in enumerate(zip(labels, colors)):
                y_values = [history[i] for history in predictions_history]
                ax.plot(timestamps_history, y_values, color=color, label=label, marker='o')
                if timestamps_history:
                    # Annotate the newest point of each series with its value.
                    ax.text(
                        timestamps_history[-1], y_values[-1], f'{label}: {y_values[-1]:.2f}',
                        color=color, fontsize=10, ha='left', va='bottom'
                    )
            ax.set_title('Predicted Candles for BTC/USDT Futures', fontsize=16)
            ax.set_xlabel('Time', fontsize=14)
            ax.set_ylabel('Price', fontsize=14)
            ax.tick_params(axis='x', labelrotation=45)
            ax.grid(True)
            ax.legend(fontsize=12)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    with gr.Blocks() as app:
        gr.Markdown("## Prediction Interface")
        gr.Markdown("This interface will give you 15 mins prediction of BTC/USD futures value")
        output = gr.Textbox(label="Prediction Result")
        plot_output = gr.Image(label="Prediction Plot", type="pil")
        button = gr.Button("Get Prediction")
        button.click(fn=process_data, inputs=[], outputs=[output, plot_output])
    app.launch(show_api=False, auth=None)
if __name__ == "__main__": | |
def fetch_data(symbol="BTC/USDT", timeframe="1m", limit=500): | |
exchange = ccxt.binanceus({ | |
"rateLimit": 1200, | |
"enableRateLimit": True, | |
}) | |
ohlcv = exchange.fetch_ohlcv(symbol, timeframe, limit=limit) | |
df = pd.DataFrame(ohlcv, columns=["timestamp", "Open", "High", "Low", "Close", "Volume"]) | |
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") | |
return df | |
def calculate_all_indicators(data): | |
data = add_all_ta_features( | |
df=data, | |
open="Open", | |
high="High", | |
low="Low", | |
close="Close", | |
volume="Volume", | |
fillna=False | |
) | |
return data | |
model_path = hf_hub_download(repo_id="alexandrlukashov/gru-model-time-series", filename="gru_model.pth") | |
config_path = hf_hub_download(repo_id="alexandrlukashov/gru-model-time-series", filename="gru_config.json") | |
with open(config_path, "r") as f: | |
config = json.load(f) | |
class GRUModel(torch.nn.Module): | |
def __init__(self, input_size, hidden_size, num_layers, output_size): | |
super(GRUModel, self).__init__() | |
self.hidden_size = hidden_size | |
self.num_layers = num_layers | |
self.gru = torch.nn.GRU(input_size, hidden_size, num_layers, batch_first=True) | |
self.fc = torch.nn.Linear(hidden_size, output_size) | |
def forward(self, x): | |
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) | |
out, _ = self.gru(x, h0) | |
out = self.fc(out[:, -1, :]) | |
return out | |
model = GRUModel( | |
input_size=config["input_size"], | |
hidden_size=config["hidden_size"], | |
num_layers=config["num_layers"], | |
output_size=4 | |
) | |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) | |
model.eval() | |
scaler_path = hf_hub_download(repo_id="alexandrlukashov/gru-model-time-series", filename="scaler_X.json") | |
with open(scaler_path, "r") as f: | |
scaler_params = json.load(f) | |
scaler = StandardScaler() | |
scaler.mean_ = np.array(scaler_params["mean"]) | |
scaler.scale_ = np.array(scaler_params["scale"]) | |
scaler.var_ = scaler.scale_**2 | |
def preprocess_input(data): | |
data = np.array(data).reshape(1, -1) | |
scaled_data = scaler.transform(data) | |
return scaled_data | |
def inverse_scale_output(predictions): | |
dummy_input = np.zeros((1, len(scaler.mean_))) | |
dummy_input[:, :4] = predictions | |
unscaled_predictions = scaler.inverse_transform(dummy_input) | |
return unscaled_predictions[0, :4] | |
def predict(inputs): | |
inputs = preprocess_input(inputs) | |
inputs_tensor = torch.tensor(inputs, dtype=torch.float32).unsqueeze(1) | |
with torch.no_grad(): | |
predictions = model(inputs_tensor).numpy() | |
predictions = inverse_scale_output(predictions) | |
return { | |
"Predicted Target": predictions[0], | |
"Predicted Target Open": predictions[1], | |
"Predicted Target High": predictions[2], | |
"Predicted Target Low": predictions[3] | |
} | |
selected_columns = [ | |
'Open', 'High', 'Low', 'Close', 'others_cr', 'trend_ema_fast', | |
'trend_ichimoku_conv', 'momentum_kama', 'volatility_kcc', 'volume_vwap', | |
'trend_sma_fast', 'trend_ichimoku_a', 'volatility_kch', 'volatility_kcl', | |
'volatility_dcm', 'trend_ema_slow', 'volatility_bbm', 'trend_ichimoku_base', | |
'trend_sma_slow', 'trend_psar_down', 'trend_psar_up', 'volatility_dch', | |
'volatility_bbh', 'trend_ichimoku_b', 'volatility_dcl', 'volatility_bbl', | |
'trend_visual_ichimoku_a', 'trend_visual_ichimoku_b' | |
] | |
gradio_interface() | |