import gradio as gr
import json
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import json
import time
import os
from huggingface_hub import CommitScheduler
from functools import partial
import pandas as pd
import numpy as np
from huggingface_hub import snapshot_download
import librosa
import random 

def enable_buttons_side_by_side():
    """Return six Gradio updates that make every button visible and clickable.

    A fresh update dict is built for each of the six outputs.
    """
    updates = [gr.update(visible=True, interactive=True) for _ in range(6)]
    return tuple(updates)

def disable_buttons_side_by_side():
    """Return six Gradio updates that make every button non-interactive.

    The first four outputs are also hidden; the last two stay visible.
    """
    hidden = [gr.update(visible=False, interactive=False) for _ in range(4)]
    shown = [gr.update(visible=True, interactive=False) for _ in range(2)]
    return tuple(hidden + shown)


# Folder for the vote/flag JSON-lines logs; the CommitScheduler below uploads it.
os.makedirs('data', exist_ok = True)
# One log file per process start, suffixed with the local start time.
# NOTE(review): isoformat() puts ':' in the file name, which some filesystems
# (e.g. Windows) reject; the two now() calls can also yield slightly
# different suffixes for the log and flag files.
LOG_FILENAME = os.path.join('data', f'log_{datetime.now().isoformat()}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{datetime.now().isoformat()}.json')

# Reusable gr.update payloads for button components.
enable_btn = gr.update(interactive=True, visible=True)
disable_btn = gr.update(interactive=False)
invisible_btn = gr.update(interactive=False, visible=False)
no_change_btn = gr.update(value="No Change", interactive=True, visible=True)

# Deployment settings read from the environment (None when unset):
#   DS_ID       - dataset repo that receives the local 'data' log folder
#   TOKEN       - Hugging Face token used for both the download and the upload
#   SONG_SOURCE - dataset repo downloaded into LOCAL_DIR; presumably provides
#                 data.csv and the songs/ folder read below (confirm)
DS_ID = os.getenv('DS_ID')
TOKEN = os.getenv('TOKEN')
SONG_SOURCE = os.getenv("SONG_SOURCE")
LOCAL_DIR = './'

# Fetch the song dataset into the working directory at import time; this must
# run before data.csv is read below.
snapshot_download(repo_id=SONG_SOURCE, repo_type="dataset", token = TOKEN, local_dir = LOCAL_DIR)

# Background uploader that commits the local 'data' folder to DS_ID under
# 'data/'. every=10 is in minutes per the huggingface_hub docs (confirm).
# The log writers in this file hold scheduler.lock while appending.
scheduler = CommitScheduler(
    repo_id= DS_ID,
    repo_type="dataset",
    folder_path= os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    token = TOKEN,
    every = 10,
)

# Song metadata. Columns read elsewhere in this file: filename, label,
# duration, id, source, algorithm.
df = pd.read_csv(os.path.join(LOCAL_DIR,'data.csv'))

# One './songs/<filename>.mp3' path per row, in df row order. get_song indexes
# this list with a df label, which equals the position only because read_csv
# produced a default RangeIndex here.
filenames = list(os.path.join(LOCAL_DIR, 'songs') + '/' + df.filename + '.mp3')

# Pool of row labels not yet served in the current round; main_indices is a
# separate copy of the full set used to refill the pool.
indices = list(df.index)
main_indices = indices.copy()

def init_indices():
    """Refill the global pool ``indices`` with every index in ``main_indices``.

    The pool is rebound to a *copy* so that later in-place operations on it
    (``np.random.shuffle`` in ``pick_and_remove_one``) never reorder or
    otherwise mutate the master list ``main_indices``.
    """
    global indices
    indices = list(main_indices)


def pick_and_remove_one():
    """Draw one random index, without replacement, from the global pool.

    When the pool ``indices`` is empty it is refilled via ``init_indices()``.
    The pool is shuffled in place, its first element is returned, and the
    global is rebound to the remaining elements.

    Returns:
        One element of ``indices`` (a row label of ``df`` in this module).
    """
    global indices
    if len(indices) == 0:
        # Round finished: start over with the full set of songs.
        init_indices()

    np.random.shuffle(indices)
    chosen = indices[0]
    # Rebinding to a slice (rather than popping) leaves the list object that
    # was just shuffled at its original length.
    indices = indices[1:]
    return chosen


def vote_last_response(state, vote_type, request: gr.Request):
    """Append one vote record, as a single JSON line, to ``LOG_FILENAME``.

    The record holds a timestamp, the vote type, ``state.dict()`` and the
    client IP from ``get_ip``. The file is opened and written while holding
    ``scheduler.lock``, presumably so the write cannot interleave with the
    scheduler pushing the log folder.
    """
    with scheduler.lock, open(LOG_FILENAME, "a") as log_file:
        entry = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        log_file.write(json.dumps(entry) + "\n")

def flag_last_response(state, vote_type, request: gr.Request):
    """Append one flag record, as a single JSON line, to ``FLAG_FILENAME``.

    Same record layout as the vote log (timestamp, type, ``state.dict()``,
    client IP). The file is opened and written while holding
    ``scheduler.lock``, presumably so the write cannot interleave with the
    scheduler pushing the log folder.
    """
    with scheduler.lock, open(FLAG_FILENAME, "a") as flag_file:
        entry = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        flag_file.write(json.dumps(entry) + "\n")


class AudioStateIG:
    """Per-round state for one song shown to a voter.

    Holds a random conversation id, the selected metadata row, and an optional
    override of the clip duration (set when the audio was cut).
    """

    def __init__(self, row):
        self.conv_id = uuid4().hex
        # Row object exposing label/filename/duration/id/source/algorithm
        # attributes (in this module, a df.loc row).
        self.row = row
        self.new_duration = None

    @staticmethod
    def _to_native(value):
        """Convert NumPy scalars (e.g. ``np.int64``) to built-in Python types.

        Values read from a pandas row are often NumPy scalars, and
        ``json.dumps`` rejects integer ones such as ``np.int64``. Any other
        value is returned unchanged.
        """
        return value.item() if isinstance(value, np.generic) else value

    def dict(self):
        """Return a JSON-serialisable summary of this state for logging.

        ``duration`` is the cut length when ``update_duration`` was called,
        otherwise the row's original duration.
        """
        duration = self.row.duration if self.new_duration is None else self.new_duration
        base = {
            "conv_id": self.conv_id,
            "label": self._to_native(self.row.label),
            "filename": self._to_native(self.row.filename),
            "duration": self._to_native(duration),
            "song_id": str(self.row.id),
            "source": self._to_native(self.row.source),
            "algorithm": self._to_native(self.row.algorithm),
        }
        return base

    def update_duration(self, duration):
        """Record the duration (in seconds) of the clip actually served."""
        self.new_duration = duration

def get_ip(request: gr.Request):
    """Return the client IP address for ``request``, or None without a request.

    Prefers the ``cf-connecting-ip`` header when it is present and non-empty,
    falling back to ``request.client.host`` otherwise.
    """
    if not request:
        return None
    if "cf-connecting-ip" not in request.headers:
        return request.client.host
    return request.headers["cf-connecting-ip"] or request.client.host


def get_song(idx, df = df, filenames = filenames):
    """Build a fresh vote state and the audio path for the song at ``idx``.

    ``idx`` is used both as a ``df.loc`` label and as a position into
    ``filenames``; the defaults bind the module-level ``df`` and
    ``filenames`` at definition time.

    Returns:
        ``(AudioStateIG, audio_path)``.
    """
    song_row = df.loc[idx]
    song_path = filenames[idx]
    return AudioStateIG(song_row), song_path

def random_cut_length(audio_data, max_length, sample_rate, lengths=(125, 55, 25)):
    """Cut an excerpt of randomly chosen length from ``audio_data``.

    The excerpt length is drawn uniformly (``random.choice``, in the order
    given) from the entries of ``lengths`` that are strictly shorter than
    ``max_length``. If none qualifies, the audio is returned unchanged. The
    start offset is a whole number of seconds, chosen uniformly among all
    offsets that keep the excerpt within ``max_length``.

    Args:
        audio_data: samples, sliced along the first axis (e.g. the array
            returned by ``librosa.load``).
        max_length: duration of ``audio_data`` in seconds; may be fractional.
        sample_rate: samples per second, used to turn seconds into indices.
        lengths: candidate excerpt lengths in seconds.

    Returns:
        ``(audio_cut, length)`` - the excerpt and its length in seconds.
    """
    options = [length for length in lengths if max_length > length]
    if not options:
        return audio_data, max_length

    length_picked = random.choice(options)
    # Largest whole-second start that still fits. int() truncates fractional
    # durations (25.5 - 25 -> 0). The +1 makes randint's exclusive upper bound
    # include that last valid start; without it, durations such as 25.5 made
    # randint(0, 0.5) raise ValueError.
    max_start = int(max_length - length_picked)
    start_point = np.random.randint(0, max_start + 1)
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked

def constant_cut_length(audio_data, max_length, sample_rate, length_picked=25):
    """Cut an excerpt of exactly ``length_picked`` seconds from ``audio_data``.

    Audio no longer than ``length_picked`` is returned unchanged. Otherwise
    the start offset is a whole number of seconds, chosen uniformly among all
    offsets that keep the excerpt within ``max_length``.

    Args:
        audio_data: samples, sliced along the first axis (e.g. the array
            returned by ``librosa.load``).
        max_length: duration of ``audio_data`` in seconds; may be fractional.
        sample_rate: samples per second, used to turn seconds into indices.
        length_picked: excerpt length in seconds.

    Returns:
        ``(audio_cut, length)`` - the excerpt and its length in seconds.
    """
    if max_length <= length_picked:
        return audio_data, max_length

    # Largest whole-second start that still fits. int() truncates fractional
    # durations (25.5 - 25 -> 0). The +1 makes randint's exclusive upper bound
    # include that last valid start; without it, durations such as 25.5 made
    # randint(0, 0.5) raise ValueError.
    max_start = int(max_length - length_picked)
    start_point = np.random.randint(0, max_start + 1)
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked

def generate_songs(state, song_cut_function = constant_cut_length):
    """Pick the next song and prepare its audio for the voting UI.

    The incoming ``state`` is ignored; a fresh state is built for the newly
    picked song.

    Args:
        state: previous state (unused).
        song_cut_function: callable ``(samples, duration, sample_rate)`` that
            returns ``(samples, new_duration)``, or None to serve the file as-is.

    Returns:
        ``(state, audio, "Vote to Reveal Label")``. ``audio`` is
        ``(sample_rate, samples)`` when a cut function is given, otherwise the
        audio file path.
    """
    new_state, audio_path = get_song(pick_and_remove_one())
    if song_cut_function is None:
        return new_state, audio_path, "Vote to Reveal Label"

    # sr=None keeps the file's native sample rate.
    samples, sample_rate = librosa.load(audio_path, sr=None)
    clip, clip_seconds = song_cut_function(samples, new_state.row.duration, sample_rate)
    new_state.update_duration(clip_seconds)
    return new_state, (sample_rate, clip), "Vote to Reveal Label"
    
def fake_last_response(state, request: gr.Request):
    """Log a "fake" vote, then reveal the song's true label.

    Returns the module's ``disable_btn`` update twice, followed by a visible
    Markdown panel. The panel shows the label, plus the row's ``algorithm``
    whenever the label is not ``'real'``.
    """
    vote_last_response(state, "fake", request)

    true_label = state.row.label
    model_line = "" if true_label == 'real' else f"\nModel : {state.row.algorithm}"
    reveal = gr.Markdown(f"### {true_label}{model_line}", visible=True)
    return disable_btn, disable_btn, reveal

def real_last_response(state, request: gr.Request):
    """Log a "real" vote, then reveal the song's true label.

    Returns the module's ``disable_btn`` update twice, followed by a visible
    Markdown panel. The panel shows the label, plus the row's ``algorithm``
    whenever the label is not ``'real'``.
    """
    vote_last_response(state, "real", request)

    true_label = state.row.label
    model_line = "" if true_label == 'real' else f"\nModel : {state.row.algorithm}"
    reveal = gr.Markdown(f"### {true_label}{model_line}", visible=True)
    return disable_btn, disable_btn, reveal