import gradio as gr
import ccxt
import pandas as pd
from ta import add_all_ta_features
import time
import torch
import json
import numpy as np
from sklearn.preprocessing import StandardScaler
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import io
from PIL import Image
import datetime


# Rolling history read and updated by plot_predictions(): each new, distinct
# prediction appends one entry to both lists (so they stay index-aligned) and
# the oldest entries are dropped once more than 10 are held.
predictions_history, timestamps_history = [], []

def gradio_interface():
    """Build and launch the Gradio app that serves on-demand predictions.

    The nested callbacks look up ``fetch_data``, ``calculate_all_indicators``,
    ``predict`` and ``selected_columns`` as module-level names when the button
    is clicked, so those must be defined before the app is used.
    """
    import logging

    logger = logging.getLogger(__name__)

    # (label, index into the prediction list, marker colour) for each series.
    series = (
        ("Close", 0, "green"),
        ("Open", 1, "red"),
        ("High", 2, "blue"),
        ("Low", 3, "purple"),
    )

    def process_data():
        """Fetch candles, run the model on one row, and render the history plot.

        Returns:
            ``(summary_text, PIL image)`` on success, or
            ``("Error: <message>", None)`` if any step raises, so the UI shows
            the message instead of failing the request.
        """
        try:
            df = calculate_all_indicators(fetch_data())
            # Second-to-last row. NOTE(review): presumably skips the still-open
            # last candle — confirm this matches how training rows were built.
            row = df.iloc[-2]
            # Missing (NaN) feature values are replaced with 0 before scaling.
            last_row_filtered = row[selected_columns].fillna(0).values.tolist()
            prediction = predict(last_row_filtered)

            result = (f"Close: {prediction['Predicted Target']:.2f}, "
                      f"Open: {prediction['Predicted Target Open']:.2f}, "
                      f"High: {prediction['Predicted Target High']:.2f}, "
                      f"Low: {prediction['Predicted Target Low']:.2f}")

            return result, plot_predictions(prediction)
        except Exception as e:
            # Top-level UI boundary: keep the full traceback in the server log
            # instead of reducing it to the one-line message shown to the user.
            logger.exception("Prediction failed")
            return f"Error: {str(e)}", None

    def predict_interface():
        """Button callback; delegates to ``process_data``.

        Kept as a separate function so the Gradio endpoint name is unchanged.
        """
        return process_data()

    def plot_predictions(predictions):
        """Record ``predictions`` in the shared history and plot that history.

        A new entry is appended only when it differs from the most recent one,
        and the history is trimmed to the 10 most recent entries.

        Returns:
            A PIL image of the scatter plot.
        """
        current_prediction = [
            predictions['Predicted Target'],
            predictions['Predicted Target Open'],
            predictions['Predicted Target High'],
            predictions['Predicted Target Low'],
        ]

        if not predictions_history or current_prediction != predictions_history[-1]:
            # Plot the point 15 minutes ahead of now, matching the horizon
            # advertised in the UI text.
            timestamps_history.append(
                datetime.datetime.now() + datetime.timedelta(minutes=15)
            )
            predictions_history.append(current_prediction)
            # Keep only the 10 most recent entries (no-op while shorter).
            del timestamps_history[:-10]
            del predictions_history[:-10]

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for label, idx, color in series:
                ax.scatter(
                    timestamps_history,
                    [p[idx] for p in predictions_history],
                    color=color,
                    s=50,
                    label=label,
                )
            ax.set_title('Predicted Candles for BTC/USDT Futures', fontsize=14)
            ax.set_xlabel('Time', fontsize=12)
            ax.set_ylabel('Price', fontsize=12)
            ax.tick_params(axis='x', labelrotation=45)
            ax.grid(True)
            ax.legend()
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each button click would leak one figure.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)

    with gr.Blocks() as app:
        gr.Markdown("## Prediction Interface")
        gr.Markdown("This interface will give you 15 mins prediction of BTC/USD futures value")

        output = gr.Textbox(label="Prediction Result")
        plot_output = gr.Image(label="Prediction Plot", type="pil")
        button = gr.Button("Get Prediction")
        button.click(fn=predict_interface, inputs=[], outputs=[output, plot_output])

    app.launch(show_api=False, auth=None)

HF_REPO_ID = "alexandrlukashov/gru-model-time-series"

# Feature columns passed to the model, in this order. The first four are the
# raw OHLC columns from fetch_data(); the rest are columns added by
# calculate_all_indicators(). preprocess_input() scales them column-by-column,
# so this order presumably must match the order used when the scaler and the
# model were fitted.
selected_columns = [
    'Open', 'High', 'Low', 'Close', 'others_cr', 'trend_ema_fast',
    'trend_ichimoku_conv', 'momentum_kama', 'volatility_kcc', 'volume_vwap',
    'trend_sma_fast', 'trend_ichimoku_a', 'volatility_kch', 'volatility_kcl',
    'volatility_dcm', 'trend_ema_slow', 'volatility_bbm', 'trend_ichimoku_base',
    'trend_sma_slow', 'trend_psar_down', 'trend_psar_up', 'volatility_dch',
    'volatility_bbh', 'trend_ichimoku_b', 'volatility_dcl', 'volatility_bbl',
    'trend_visual_ichimoku_a', 'trend_visual_ichimoku_b'
]


def fetch_data(symbol="BTC/USDT", timeframe="1m", limit=500):
    """Download the most recent OHLCV candles from Binance.US via ccxt.

    Args:
        symbol: Market symbol to query.
        timeframe: ccxt candle timeframe string.
        limit: Maximum number of candles to request.

    Returns:
        A DataFrame with columns ``timestamp`` (converted from epoch
        milliseconds), ``Open``, ``High``, ``Low``, ``Close`` and ``Volume``.

    NOTE(review): the UI advertises a 15-minute horizon while the default
    timeframe is 1m — confirm this matches the model's training data.
    """
    exchange = ccxt.binanceus({
        "rateLimit": 1200,
        "enableRateLimit": True,
    })
    ohlcv = exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
    df = pd.DataFrame(ohlcv, columns=["timestamp", "Open", "High", "Low", "Close", "Volume"])
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    return df


def calculate_all_indicators(data):
    """Return ``data`` with every ``ta`` library indicator column appended.

    ``fillna=False`` leaves NaN values in the indicator columns; the caller is
    responsible for filling them.
    """
    return add_all_ta_features(
        df=data,
        open="Open",
        high="High",
        low="Low",
        close="Close",
        volume="Volume",
        fillna=False
    )


class GRUModel(torch.nn.Module):
    """GRU encoder followed by a linear head applied to the last time step.

    The submodule attribute names (``gru``, ``fc``) determine the state_dict
    keys and must match the saved checkpoint.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = torch.nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, output_size)."""
        # Zero initial hidden state, created on the input's device and dtype so
        # float64 inputs work as well as float32 ones.
        h0 = torch.zeros(
            self.num_layers, x.size(0), self.hidden_size,
            device=x.device, dtype=x.dtype,
        )
        out, _ = self.gru(x, h0)
        # Project only the output of the final time step.
        return self.fc(out[:, -1, :])


def preprocess_input(data):
    """Scale one feature row with the global ``scaler``; returns shape (1, n_features)."""
    return scaler.transform(np.array(data).reshape(1, -1))


def inverse_scale_output(predictions):
    """Map scaled model outputs back to price units.

    ``predictions`` (shape (1, 4) as produced by predict()) is written into
    the first four columns of an otherwise-zero row, the row is passed through
    ``scaler.inverse_transform``, and the first four values are returned.

    NOTE(review): predict() labels the outputs (Close, Open, High, Low), but
    scaler columns 0-3 belong to selected_columns[:4] = (Open, High, Low,
    Close). Unless the training targets were scaled the same way, each output
    is unscaled with a neighbouring column's mean/scale — confirm against the
    training code.
    """
    dummy_input = np.zeros((1, len(scaler.mean_)))
    dummy_input[:, :4] = predictions
    return scaler.inverse_transform(dummy_input)[0, :4]


def predict(inputs):
    """Run the global ``model`` on one feature row and return unscaled prices.

    The row is scaled, shaped as (batch=1, seq_len=1, n_features), passed
    through the model without gradient tracking, and mapped back to price
    units.
    """
    scaled = preprocess_input(inputs)
    inputs_tensor = torch.tensor(scaled, dtype=torch.float32).unsqueeze(1)
    with torch.no_grad():
        raw = model(inputs_tensor).numpy()
    predictions = inverse_scale_output(raw)
    return {
        "Predicted Target": predictions[0],
        "Predicted Target Open": predictions[1],
        "Predicted Target High": predictions[2],
        "Predicted Target Low": predictions[3]
    }


if __name__ == "__main__":
    # Download artifacts and launch the app only when run as a script, so
    # importing this module has no network side effects.
    model_path = hf_hub_download(repo_id=HF_REPO_ID, filename="gru_model.pth")
    config_path = hf_hub_download(repo_id=HF_REPO_ID, filename="gru_config.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Fail fast at startup rather than on every button click.
    if config["input_size"] != len(selected_columns):
        raise ValueError(
            f"Model expects {config['input_size']} input features but "
            f"{len(selected_columns)} columns are selected"
        )

    model = GRUModel(
        input_size=config["input_size"],
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        output_size=4
    )
    # weights_only=True: the checkpoint comes from a remote repository, and a
    # full unpickle would execute arbitrary code embedded in it. Requires
    # torch >= 1.13.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
    )
    model.eval()

    scaler_path = hf_hub_download(repo_id=HF_REPO_ID, filename="scaler_X.json")

    with open(scaler_path, "r", encoding="utf-8") as f:
        scaler_params = json.load(f)

    # Rebuild a fitted StandardScaler from its saved statistics.
    scaler = StandardScaler()
    scaler.mean_ = np.array(scaler_params["mean"])
    scaler.scale_ = np.array(scaler_params["scale"])
    scaler.var_ = scaler.scale_**2
    # Lets sklearn reject wrongly sized rows with a clear error message.
    scaler.n_features_in_ = scaler.mean_.shape[0]
    if scaler.n_features_in_ != len(selected_columns):
        raise ValueError(
            f"Scaler was fitted on {scaler.n_features_in_} features but "
            f"{len(selected_columns)} columns are selected"
        )

    gradio_interface()