# import sklearn
import gradio as gr
# import joblib
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from PIL import Image
# import datasets

# pipe = joblib.load("./model.pkl")

# Heading and blurb shown at the top of the Gradio interface.
title = (
    "RegMix: Data Mixture as Regression "
    "for Language Model Pre-training"
)
description = (
    "We propose a regression-based method to find "
    "high-performing data mixture for language model pre-training."
)

def _fit_regressor(df):
    """Fit a LightGBM regressor on ``df`` and return ``(model, df_val)``.

    All columns except the last are features; the last column is the target.
    12.5% of the rows (``random_state=42``) are held out as the early-stopping
    validation split, which is returned alongside the fitted model.
    """
    x_columns = df.columns[:-1]
    y_column = df.columns[-1]

    df_train, df_val = train_test_split(df, test_size=0.125, random_state=42)

    hyper_params = {
        'task': 'train',
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': ['l1', 'l2'],
        'num_iterations': 1000,
        'seed': 42,
        'learning_rate': 1e-2,
    }

    np.random.seed(42)
    gbm = lgb.LGBMRegressor(**hyper_params)
    # Early-stop after 3 rounds without improvement on the validation split.
    model = gbm.fit(
        df_train[x_columns].values,
        df_train[y_column],
        eval_set=[(df_val[x_columns].values, df_val[y_column])],
        eval_metric='l2',
        callbacks=[lgb.early_stopping(stopping_rounds=3)],
    )
    return model, df_val


def _predict_simplex_grid(model, stride=0.025):
    """Evaluate ``model`` on a regular grid over the 3-component mixture simplex.

    Returns meshgrids ``(X, Y, Z)``. ``X`` and ``Y`` are the first two mixture
    weights, and the third weight is ``1 - X - Y``. Cells outside the simplex
    (``X + Y > 1``) are set to ``np.inf`` so that later min/argmin calls skip them.
    """
    axis = np.arange(0, 1 + stride, stride)
    X, Y = np.meshgrid(axis, axis)
    xs, ys = X.reshape(-1), Y.reshape(-1)
    inside = ~((xs + ys) > 1)

    Z = np.full(xs.shape, np.inf)
    if inside.any():
        features = np.column_stack(
            [xs[inside], ys[inside], 1 - xs[inside] - ys[inside]]
        )
        # One batched predict instead of one call per grid point.
        Z[inside] = model.predict(features)
    return X, Y, Z.reshape(X.shape)


def _add_point(ax, x, y, z, fc=None, ec=None, radius=0.005):
    """Mark (x, y, z) on a 3D axes with three circles, one per axis-aligned plane."""
    from matplotlib.patches import Circle
    from mpl_toolkits.mplot3d import art3d

    planes = (('z', (x, y, z)), ('y', (x, z, y)), ('x', (y, z, x)))
    for zdir, (u, v, depth) in planes:
        patch = Circle((u, v), radius, lw=1.5, fc=fc, ec=ec)
        ax.add_patch(patch)
        art3d.pathpatch_2d_to_3d(patch, z=depth, zdir=zdir)


def _render_mixture_surface(model, stride=0.025, floor_offset=0.35):
    """Plot ``model``'s predictions over the mixture simplex and return a PIL image.

    The figure contains:
    * a 3D surface of the predicted target;
    * a filled-contour projection of that surface on a floor placed
      ``floor_offset`` below the surface minimum;
    * red markers at the predicted minimum, on both the surface and the floor;
    * a colour bar.

    The figure is rendered into memory and then closed. Repeated calls
    therefore neither accumulate pyplot figures nor share a temporary file on disk.
    """
    import io

    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.ticker import LinearLocator

    X, Y, Z = _predict_simplex_grid(model, stride)
    finite = Z[Z != np.inf]
    z_min, z_max = finite.min(), finite.max()
    z_floor = z_min - floor_offset
    best = np.argmin(Z)  # inf cells never win, so this is the in-simplex minimum
    best_x, best_y = X.reshape(-1)[best], Y.reshape(-1)[best]

    style = {"font.family": "Times New Roman", "font.size": 24, "axes.labelpad": 20}
    buffer = io.BytesIO()
    # Scope the font settings to this figure instead of mutating global rcParams.
    with plt.rc_context(style):
        fig, ax = plt.subplots(figsize=(12, 12), layout='compressed',
                               subplot_kw={"projection": "3d"})
        try:
            surf = ax.plot_surface(X, Y, Z,
                                   edgecolor='white',
                                   lw=0.5, rstride=2, cstride=2,
                                   alpha=0.85,
                                   cmap='coolwarm',
                                   vmin=z_min,
                                   vmax=z_max,
                                   antialiased=False)

            ax.zaxis.set_major_locator(LinearLocator(10))
            ax.zaxis.set_major_formatter('{x:.02f}')
            ax.view_init(elev=25, azim=45, roll=0)

            # Project the surface as filled contours onto a floor below it.
            ax.contourf(X, Y, Z, zdir='z', offset=z_floor, cmap=cm.coolwarm)

            # Mark the predicted optimum on the surface and on its projection.
            _add_point(ax, best_x, best_y, z_min, fc='Red', ec='Red', radius=0.015)
            _add_point(ax, best_x, best_y, z_floor, fc='Red', ec='Red', radius=0.015)

            ax.set_xlabel('Github (%)', fontdict={'size': 24})
            ax.set_ylabel('Hacker News (%)', fontdict={'size': 24})

            # Axis values are fractions; label them as percentages.
            percent_ticks = np.arange(0, 1, 0.2)
            percent_labels = [str(np.round(num, 1)) for num in np.arange(0, 100, 20)]
            ax.set_xticks(percent_ticks, percent_labels)
            ax.set_yticks(percent_ticks, percent_labels)

            z_ticks = np.arange(z_min, z_max, 0.2)
            ax.set_zticks(z_ticks, [str(np.round(num, 1)) for num in z_ticks])
            ax.zaxis.labelpad = 1

            ax.set_zlim(z_floor, z_max + 0.01)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_box_aspect(aspect=None, zoom=0.775)

            # NOTE(review): private matplotlib API, presumably used to move where
            # the z-axis is drawn; it may break on matplotlib upgrades.
            ax.zaxis._axinfo['juggled'] = (1, 2, 2)

            cbar = fig.colorbar(surf, shrink=0.5, aspect=25, pad=0.01)
            cbar.ax.set_ylabel('Prediction', fontdict={'size': 32})

            fig.savefig(buffer, format="png", bbox_inches='tight', pad_inches=0.1)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    buffer.seek(0)
    image = Image.open(buffer)
    image.load()
    return image


def infer(inputs, additional_inputs):
    """Fit the mixture regressor on the input table and build the Gradio outputs.

    Args:
        inputs: Contents of the "Dataset" table, either a DataFrame or row data.
            The leading columns are mixture weights and the last column is the
            target to predict.
        additional_inputs: Value of the upload button. It is unused here because
            the button's ``upload`` handler already copies the CSV into the table.

    Returns:
        list: ``[scatter, image, results]``, where
        * ``scatter`` is a ``gr.ScatterPlot`` of prediction vs. target on the
          validation split;
        * ``image`` is a ``gr.Image`` of the predicted surface over the
          mixture simplex;
        * ``results`` is a DataFrame with ``Target`` and ``Prediction`` columns
          for the validation rows.

    Raises:
        gr.Error: If the table has fewer than two rows, so there is nothing
            to split into train and validation sets.
    """
    if isinstance(inputs, pd.DataFrame) and not set(headers).issubset(inputs.columns):
        # An uploaded CSV may carry its own column names. Reindexing it by the
        # default `headers` would silently turn every value into NaN.
        df = inputs.copy()
    else:
        df = pd.DataFrame(inputs, columns=headers)
    if len(df) < 2:
        raise gr.Error("Need at least two rows: one to train on and one to validate.")

    x_columns = df.columns[:-1]
    y_column = df.columns[-1]
    model, df_val = _fit_regressor(df)
    predictions = model.predict(df_val[x_columns].values)

    # Report the target under a fixed "Target" name so the scatter plot and the
    # results table work whatever the target column is called.
    results = pd.DataFrame(
        {"Target": df_val[y_column].to_numpy(), "Prediction": predictions},
        index=df_val.index,
    )
    lo = min(results["Prediction"].min(), results["Target"].min()) - 0.25
    hi = max(results["Prediction"].max(), results["Target"].max()) + 0.25

    image = _render_mixture_surface(model)

    return [
        gr.ScatterPlot(
            value=results,
            x="Prediction",
            y="Target",
            title="Scatter",
            tooltip=["Prediction", "Target"],
            x_lim=[lo, hi],
            y_lim=[lo, hi],
        ),
        gr.Image(image),
        results[['Target', 'Prediction']],
    ]

def upload_csv(file):
    """Read an uploaded CSV file into a DataFrame for the input table.

    Args:
        file: The value delivered by the upload button. This may be a path
            (``str`` or ``os.PathLike``) or a tempfile-like object whose
            ``.name`` is the path on disk. Which shape arrives depends on the
            Gradio version and the button's ``type`` setting.

    Returns:
        pandas.DataFrame: The parsed CSV contents.
    """
    import os

    # Paths are used as-is. Note that pathlib.Path also has `.name`, but that
    # is only the basename. Other objects are assumed to expose their path as `.name`.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    return pd.read_csv(path)

# Seed data from data.csv. It is the default example, and its column names
# become the input table's headers.
df = pd.read_csv('data.csv')
headers = df.columns.tolist()

# The input table is fixed at 4 columns because `infer` sweeps the first three
# as mixture weights [x, y, 1 - x - y] and treats the last as the target.
inputs = [
    gr.Dataframe(
        headers=headers,
        row_count=(8, "dynamic"),
        datatype='number',
        col_count=(4, "fixed"),
        label="Dataset",
        interactive=True,
    )
]
# Outputs, in the order `infer` returns them: scatter plot, surface image,
# and the validation-set Target/Prediction table.
outputs = [
    gr.ScatterPlot(),
    gr.Image(),
    gr.Dataframe(
        row_count=(2, "dynamic"),
        col_count=(2, "fixed"),
        datatype='number',
        label="Results",
        headers=["Target", "Prediction"],
    ),
]

# Assemble the app: a CSV upload button that refills the input table, and the
# Interface that runs `infer` on that table.
with gr.Blocks() as demo:

    # Created with render=False and passed to gr.Interface as an additional
    # input below, presumably so it appears inside the 'Upload CSV' accordion.
    upload_button = gr.UploadButton(label="Upload", file_types = ['.csv'], 
                                    # live=True, 
                                    file_count = "single", render=False)    
    # On upload, `upload_csv` parses the file and its DataFrame replaces the
    # contents of the input table (`inputs`).
    upload_button.upload(fn=upload_csv, inputs=upload_button, outputs=inputs, api_name="upload_csv")

    # examples: the first entry preloads data.csv (`df`).
    # NOTE(review): the purpose of the second, empty example is unclear; confirm it is intended.
    # NOTE(review): `allow_flagging` may be deprecated in newer Gradio releases;
    # verify this against the pinned Gradio version.
    gr.Interface(infer, inputs=inputs, outputs=outputs, title=title,
                additional_inputs = [upload_button], 
                additional_inputs_accordion='Upload CSV',
                description = description, 
                examples=[[df], []], 
                cache_examples=False, allow_flagging='never')
    

# Start the Gradio server. This runs whenever the module is executed.
demo.launch(debug=False)