import gradio as gr
from huggingface_hub import from_pretrained_keras
import pandas as pd
import numpy as np
import json
from matplotlib import pyplot as plt

# Normalization statistics and anomaly threshold produced at training time.
# Expected keys: "mean", "std", "threshold".
with open('scaler.json') as f:  # context manager: the original leaked the file handle
    scaler = json.load(f)

# Sliding-window length fed to the autoencoder.
# NOTE(review): 288 presumably = one day of 5-minute samples (as in the
# upstream Keras example) — confirm against the training data cadence.
TIME_STEPS = 288

# Build overlapping sliding-window sequences for the model.
def create_sequences(values, time_steps=TIME_STEPS):
    """Return all length-`time_steps` sliding windows over `values`.

    Produces an array of shape (len(values) - time_steps + 1, time_steps, ...)
    by stacking every contiguous window, stepping one sample at a time.
    """
    windows = [
        values[start : start + time_steps]
        for start in range(len(values) - time_steps + 1)
    ]
    return np.stack(windows)


def normalize_data(data):
    """Standardize `data` with the training-set mean/std loaded from scaler.json."""
    return (data - scaler["mean"]) / scaler["std"]

def plot_test_data(df_test_value):
    """Render the (normalized) input series on a fresh figure and return it."""
    figure, axes = plt.subplots(figsize=(12, 6))
    df_test_value.plot(legend=False, ax=axes)
    axes.set_title("Input Test Data")
    axes.set_xlabel("Time")
    axes.set_ylabel("Value")
    return figure

# Memo so the pretrained model is downloaded/loaded at most once per process
# (the original re-fetched it from the HF Hub on every request).
_MODEL_CACHE = {}

def _load_model():
    """Load the pretrained autoencoder from the HF Hub, caching after the first call."""
    if "model" not in _MODEL_CACHE:
        _MODEL_CACHE["model"] = from_pretrained_keras("keras-io/timeseries-anomaly-detection")
    return _MODEL_CACHE["model"]

def get_anomalies(df_test_value):
    """Return a boolean array: True where a sliding window's reconstruction
    error exceeds the training-time threshold.

    Parameters: df_test_value — normalized series (pandas object with .values).
    Returns: 1-D numpy bool array, one entry per sliding window.
    """
    # Create sequences from test values.
    x_test = create_sequences(df_test_value.values)
    model = _load_model()

    # Reconstruction (MAE) loss per window: mean |prediction - input| over the window.
    x_test_pred = model.predict(x_test)
    test_mae_loss = np.mean(np.abs(x_test_pred - x_test), axis=1)
    test_mae_loss = test_mae_loss.reshape((-1))

    # A window is anomalous when its loss exceeds the threshold stored in scaler.json.
    anomalies = test_mae_loss > scaler["threshold"]
    return anomalies

def plot_anomalies(df_test_value, data, anomalies):
    """Plot the original series with the anomalous points overlaid in red.

    Parameters:
        df_test_value: normalized series (used only for its length here).
        data: original (un-normalized) series, indexed by timestamp.
        anomalies: boolean array, one entry per sliding window
                   (output of get_anomalies).
    Returns: the matplotlib figure.
    """
    # data i is an anomaly if samples [(i - timesteps + 1) to (i)] are anomalies
    # NOTE(review): the slice below ends at data_idx (exclusive), so it actually
    # checks windows (i - TIME_STEPS + 1) .. (i - 1), not window i itself — an
    # apparent off-by-one vs. the comment above. It matches the upstream Keras
    # example verbatim, so it is preserved as-is; confirm intended behavior.
    anomalous_data_indices = []
    for data_idx in range(TIME_STEPS - 1, len(df_test_value) - TIME_STEPS + 1):
        if np.all(anomalies[data_idx - TIME_STEPS + 1 : data_idx]):
            anomalous_data_indices.append(data_idx)
    # Positional selection: anomalous_data_indices are row positions, not labels.
    df_subset = data.iloc[anomalous_data_indices]
    fig, ax = plt.subplots(figsize=(12, 6))
    data.plot(legend=False, ax=ax)
    # Red overlay of the anomalous points on the same axes.
    df_subset.plot(legend=False, ax=ax, color="r")
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    ax.set_title("Anomalous Data Points")
    return fig

def _to_timestamp_value(df, date_col, value_col):
    """Build the canonical (timestamp, value) frame from a (date_col, Hour, value_col) frame."""
    df = df.copy()  # avoid SettingWithCopyWarning from mutating a slice
    # Combine date + integer hour into one datetime. Note: to_timedelta(24, 'h')
    # already rolls over to 00:00 of the next day, so no special-casing of
    # "hour 24" is needed (the original's `.dt.hour == 24` check was dead code:
    # .dt.hour is always in 0..23).
    df["timestamp"] = pd.to_datetime(df[date_col]) + pd.to_timedelta(df["Hour"].astype(int), unit='h')
    df["timestamp"] = df["timestamp"].dt.floor('h')
    # Keep only the canonical columns, renaming the value column.
    return df[["timestamp", value_col]].rename(columns={value_col: "value"})


def clean_data(df):
    """Normalize an uploaded DataFrame to the canonical (timestamp, value) schema.

    Accepts any of three input layouts:
      1. already-canonical: columns "timestamp", "value";
      2. labor export:  "Date", "Hour", "Hourly_Labor_Hours_Total";
      3. sales export:  "Date_CY", "Hour", "Net_Sales_CY" (rows with NaNs dropped).

    Returns: a DataFrame with columns "timestamp" (datetime64) and "value".
    Raises: ValueError if none of the recognized column sets is present.
    """
    # Already in canonical form: just coerce the timestamp dtype.
    if "timestamp" in df.columns and "value" in df.columns:
        df = df.copy()
        df["timestamp"] = pd.to_datetime(df["timestamp"])
        return df

    # Labor-hours export.
    if {"Date", "Hour", "Hourly_Labor_Hours_Total"}.issubset(df.columns):
        return _to_timestamp_value(df, "Date", "Hourly_Labor_Hours_Total")

    # Net-sales export: drop incomplete rows before converting.
    if {"Date_CY", "Hour", "Net_Sales_CY"}.issubset(df.columns):
        df = df.dropna(subset=['Date_CY', 'Hour', 'Net_Sales_CY'])
        return _to_timestamp_value(df, "Date_CY", "Net_Sales_CY")

    raise ValueError("Dataframe does not contain necessary columns.")

def master(file):
    """End-to-end pipeline behind the Gradio interface.

    Reads the uploaded CSV, normalizes its schema, and returns a matplotlib
    figure: either a "not enough data" message, or the series with anomalous
    points highlighted.
    """
    data = pd.read_csv(file.name)
    print(f"Original data shape: {data.shape}")  # Debug statement
    data = clean_data(data)
    print(f"Cleaned data shape: {data.shape}")  # Debug statement
    data['timestamp'] = pd.to_datetime(data['timestamp'])
    data.set_index("timestamp", inplace=True)

    # The model needs at least one full window; show a message figure instead
    # of crashing on short uploads.
    if len(data) < TIME_STEPS:
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.text(0.5, 0.5,
                f"Not enough data to create sequences. Need at least {TIME_STEPS} records.",
                horizontalalignment='center', verticalalignment='center', fontsize=14)
        ax.axis('off')  # target the created axes explicitly, not plt's implicit current axes
        return fig

    df_test_value = normalize_data(data)
    # (The original also rendered plot_test_data(df_test_value) here but
    # discarded the figure — removed as wasted work.)
    anomalies = get_anomalies(df_test_value)
    return plot_anomalies(df_test_value, data, anomalies)

# `master` returns a matplotlib Figure, so gr.Plot is the matching output
# component. The gr.inputs / gr.outputs namespaces used originally are the
# deprecated pre-3.x API (removed in Gradio 4).
outputs = gr.Plot()

iface = gr.Interface(
    fn=master,
    inputs=gr.File(label="CSV File"),
    outputs=outputs,
    examples=["art_daily_jumpsup.csv", "labor_hourly_short.csv", "sales_hourly_short.csv"],
    title="Timeseries Anomaly Detection Using an Autoencoder",
    description="Anomaly detection of timeseries data.",
)

iface.launch()