|
import datetime
import functools
import io
import json
import logging
import time

import ccxt
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from sklearn.preprocessing import StandardScaler
from ta import add_all_ta_features
|
|
|
|
|
# Rolling prediction history shared across all UI requests; the two lists are
# kept in lockstep (same length, same order) by plot_predictions.
predictions_history: list = []  # each entry: [Close, Open, High, Low]
timestamps_history: list = []  # datetime stamp for the matching entry
|
|
|
def gradio_interface():
    """Build and launch the Gradio app for the BTC/USDT candle predictor.

    The nested handlers call module-level names defined elsewhere in this
    file (``fetch_data``, ``calculate_all_indicators``, ``selected_columns``,
    ``predict``) and record each new prediction in the module-level
    ``predictions_history`` / ``timestamps_history`` lists. This call blocks
    inside ``app.launch``.
    """
    # How many past predictions the plot keeps.
    history_limit = 10
    # Marker color and legend label for each position of a stored prediction,
    # in storage order: [Close, Open, High, Low].
    series_styles = (
        ("green", "Close"),
        ("red", "Open"),
        ("blue", "High"),
        ("purple", "Low"),
    )

    def process_data():
        """Run one prediction end to end.

        Returns:
            A ``(text, image)`` pair for the two output components. On any
            failure returns ``("Error: <message>", None)`` so the UI shows the
            message; the full traceback is written to the server log.
        """
        try:
            df = fetch_data()
            df = calculate_all_indicators(df)
            # Second-to-last row. NOTE(review): presumably because the newest
            # candle is still forming — confirm.
            row = df.iloc[-2]
            # Indicators are computed with fillna=False, so missing values are
            # replaced with 0 here before reaching the model.
            last_row_filtered = row[selected_columns].fillna(0).values.tolist()
            prediction = predict(last_row_filtered)

            result = (f"Close: {prediction['Predicted Target']:.2f}, "
                      f"Open: {prediction['Predicted Target Open']:.2f}, "
                      f"High: {prediction['Predicted Target High']:.2f}, "
                      f"Low: {prediction['Predicted Target Low']:.2f}")

            plot_img = plot_predictions(prediction)
            return result, plot_img
        except Exception as e:
            # UI boundary: show the error in the textbox instead of failing the
            # handler, but keep the traceback for whoever runs the server.
            logging.getLogger(__name__).exception("Prediction failed")
            return f"Error: {str(e)}", None

    def predict_interface():
        """Click handler for the "Get Prediction" button; returns (text, image)."""
        return process_data()

    def plot_predictions(predictions):
        """Record *predictions* in the shared history and render it as an image.

        Args:
            predictions: dict returned by ``predict`` with the keys
                'Predicted Target', 'Predicted Target Open',
                'Predicted Target High' and 'Predicted Target Low'.

        Returns:
            A PIL image of the last ``history_limit`` distinct predictions.
            A new point is only added when it differs from the previous one; it
            is stamped with the current local time plus 15 minutes.
        """
        current_prediction = [
            predictions['Predicted Target'],
            predictions['Predicted Target Open'],
            predictions['Predicted Target High'],
            predictions['Predicted Target Low'],
        ]

        if not predictions_history or current_prediction != predictions_history[-1]:
            timestamps_history.append(datetime.datetime.now() + datetime.timedelta(minutes=15))
            predictions_history.append(current_prediction)
            # Trim in place: the lists are module-level and shared between calls.
            del timestamps_history[:-history_limit]
            del predictions_history[:-history_limit]

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            # One scatter per series so each color gets a legend entry.
            for position, (color, label) in enumerate(series_styles):
                values = [entry[position] for entry in predictions_history]
                ax.scatter(timestamps_history, values, color=color, s=50, label=label)

            ax.set_title('Predicted Candles for BTC/USDT Futures', fontsize=14)
            ax.set_xlabel('Time', fontsize=12)
            ax.set_ylabel('Price', fontsize=12)
            ax.tick_params(axis='x', labelrotation=45)
            ax.grid(True)
            ax.legend()
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # pyplot keeps every figure alive until closed; without this each
            # click leaks a figure.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now instead of lazily reading from the buffer later
        return img

    with gr.Blocks() as app:
        gr.Markdown("## Prediction Interface")
        gr.Markdown("This interface will give you 15 mins prediction of BTC/USD futures value")

        output = gr.Textbox(label="Prediction Result")
        plot_output = gr.Image(label="Prediction Plot", type="pil")
        button = gr.Button("Get Prediction")
        button.click(fn=predict_interface, inputs=[], outputs=[output, plot_output])

    app.launch(show_api=False, auth=None)
|
|
|
# Pure definitions live at module level so they can be imported (and tested)
# without network access; only the side-effecting setup — Hub downloads,
# model/scaler construction and launching the UI — runs under the guard below.

HF_REPO_ID = "alexandrlukashov/gru-model-time-series"

# Model input features, in the order the model and the feature scaler expect.
selected_columns = [
    'Open', 'High', 'Low', 'Close', 'others_cr', 'trend_ema_fast',
    'trend_ichimoku_conv', 'momentum_kama', 'volatility_kcc', 'volume_vwap',
    'trend_sma_fast', 'trend_ichimoku_a', 'volatility_kch', 'volatility_kcl',
    'volatility_dcm', 'trend_ema_slow', 'volatility_bbm', 'trend_ichimoku_base',
    'trend_sma_slow', 'trend_psar_down', 'trend_psar_up', 'volatility_dch',
    'volatility_bbh', 'trend_ichimoku_b', 'volatility_dcl', 'volatility_bbl',
    'trend_visual_ichimoku_a', 'trend_visual_ichimoku_b',
]


@functools.lru_cache(maxsize=1)
def _get_exchange():
    """Return one shared Binance.US client, created on first use.

    ccxt tracks rate limiting and loaded markets per exchange instance, so
    reusing a single client lets ``enableRateLimit`` throttle across requests
    and avoids reloading markets on every call.
    """
    return ccxt.binanceus({
        "rateLimit": 1200,
        "enableRateLimit": True,
    })


def fetch_data(symbol="BTC/USDT", timeframe="1m", limit=500):
    """Fetch the latest *limit* OHLCV candles for *symbol* from Binance.US.

    Returns:
        DataFrame with columns timestamp, Open, High, Low, Close, Volume;
        ``timestamp`` is converted from epoch milliseconds to datetime.
    """
    ohlcv = _get_exchange().fetch_ohlcv(symbol, timeframe, limit=limit)
    df = pd.DataFrame(ohlcv, columns=["timestamp", "Open", "High", "Low", "Close", "Volume"])
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    return df


def calculate_all_indicators(data):
    """Append every ``ta`` indicator column to *data* and return the frame.

    ``fillna=False`` leaves missing indicator values as NaN; the caller
    decides how to fill them.
    """
    return add_all_ta_features(
        df=data,
        open="Open",
        high="High",
        low="Low",
        close="Close",
        volume="Volume",
        fillna=False,
    )


class GRUModel(torch.nn.Module):
    """Stacked GRU followed by a linear head applied to the last time step."""

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Attribute names ``gru`` / ``fc`` must match the saved state_dict keys.
        self.gru = torch.nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size)."""
        # Zero initial state, created with x's dtype and device.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        return self.fc(out[:, -1, :])


def preprocess_input(data):
    """Scale one feature vector with the module-level ``scaler``.

    Returns an array of shape (1, n_features).
    """
    data = np.array(data).reshape(1, -1)
    return scaler.transform(data)


def inverse_scale_output(predictions):
    """Map the model's four scaled outputs back to price units.

    Only the feature scaler (scaler_X.json) is available, so the outputs are
    written into the first four slots of an all-zero feature row. That row is
    inverse-transformed, and those four slots are returned.

    NOTE(review): feature columns 0-3 are Open, High, Low, Close
    (``selected_columns``), while ``predict`` labels the outputs Close, Open,
    High, Low. Each output is therefore un-scaled with a neighbouring column's
    mean/scale. Confirm the model's output order against training.
    """
    dummy_input = np.zeros((1, len(scaler.mean_)))
    dummy_input[:, :4] = predictions
    unscaled_predictions = scaler.inverse_transform(dummy_input)
    return unscaled_predictions[0, :4]


def predict(inputs):
    """Predict the next candle from one unscaled feature vector.

    The vector is fed to the module-level ``model`` as a batch of one
    sequence of length one. Returns a dict with the keys 'Predicted Target'
    (shown as Close in the UI), 'Predicted Target Open',
    'Predicted Target High' and 'Predicted Target Low'.
    """
    inputs = preprocess_input(inputs)
    inputs_tensor = torch.tensor(inputs, dtype=torch.float32).unsqueeze(1)  # (1, 1, n_features)
    with torch.no_grad():
        predictions = model(inputs_tensor).numpy()
    predictions = inverse_scale_output(predictions)
    return {
        "Predicted Target": predictions[0],
        "Predicted Target Open": predictions[1],
        "Predicted Target High": predictions[2],
        "Predicted Target Low": predictions[3],
    }


if __name__ == "__main__":
    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename="gru_model.pth")
    config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="gru_config.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = GRUModel(
        input_size=config["input_size"],
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        output_size=4,
    )
    # The checkpoint is a plain state_dict fetched from the network; with
    # weights_only=True, torch.load refuses arbitrary pickled objects.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
    )
    model.eval()

    scaler_path = hf_hub_download(repo_id=HF_REPO_ID, filename="scaler_X.json")
    with open(scaler_path, "r", encoding="utf-8") as f:
        scaler_params = json.load(f)

    # Rebuild a fitted StandardScaler from the saved statistics instead of refitting.
    scaler = StandardScaler()
    scaler.mean_ = np.array(scaler_params["mean"])
    scaler.scale_ = np.array(scaler_params["scale"])
    scaler.var_ = scaler.scale_ ** 2

    gradio_interface()
|
|