|
import gradio as gr |
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
from uuid import uuid4 |
|
import json |
|
import time |
|
import os |
|
from huggingface_hub import CommitScheduler |
|
from functools import partial |
|
import pandas as pd |
|
import numpy as np |
|
from huggingface_hub import snapshot_download |
|
|
|
def enable_buttons_side_by_side():
    """Return six Gradio updates that make each target component visible and clickable.

    One fresh update object is produced per output slot; Gradio applies them
    positionally to the event's output components.
    """
    slot_count = 6
    return tuple([gr.update(visible=True, interactive=True) for _ in range(slot_count)])
|
|
|
def disable_buttons_side_by_side():
    """Return six Gradio updates that make every target component non-interactive.

    The first four outputs are also hidden; the last two stay visible
    (but still disabled).
    """
    hidden = [gr.update(visible=False, interactive=False) for _ in range(4)]
    still_shown = [gr.update(visible=True, interactive=False) for _ in range(2)]
    return tuple(hidden + still_shown)
|
|
|
|
|
# Local folder that receives the vote/flag logs (it is also the folder the
# CommitScheduler uploads, via os.path.dirname(LOG_FILENAME)).
os.makedirs('data', exist_ok=True)

# Take the timestamp once so the vote log and the flag log created by the same
# process share one suffix and can be paired up later; two separate
# datetime.now() calls always produced two different (microsecond) suffixes.
# NOTE(review): isoformat() contains ':' characters, which are not valid in
# Windows file names — consider a colon-free format if the dataset is ever
# checked out on Windows.
# The writers append one JSON object per line, despite the ".json" suffix.
_SESSION_STAMP = datetime.now().isoformat()
LOG_FILENAME = os.path.join('data', f'log_{_SESSION_STAMP}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{_SESSION_STAMP}.json')
|
|
|
# Reusable gr.update payloads for button components.

# Hide the button and make it unclickable.
invisible_btn = gr.update(
    interactive=False,
    visible=False,
)

# Grey the button out; visibility is left unchanged.
disable_btn = gr.update(
    interactive=False,
)

# Show the button and make it clickable.
enable_btn = gr.update(
    interactive=True,
    visible=True,
)

# NOTE(review): despite its name this *sets* the button label to "No Change"
# (and shows/enables it) rather than leaving it untouched — confirm intended.
no_change_btn = gr.update(
    value="No Change",
    interactive=True,
    visible=True,
)
|
|
|
# Runtime configuration read from the environment; each is None when unset.
# Dataset repo that receives the vote/flag logs.
DS_ID = os.environ.get('DS_ID')
# Hugging Face access token used for both download and upload.
TOKEN = os.environ.get('TOKEN')
# Dataset repo holding the songs that get downloaded at startup.
SONG_SOURCE = os.environ.get("SONG_SOURCE")

# Download target for SONG_SOURCE (the current working directory).
LOCAL_DIR = './'
|
|
|
# Mirror the SONG_SOURCE dataset repo into LOCAL_DIR at import time; this
# presumably provides the data.csv and songs/ folder loaded afterwards.
# NOTE(review): LOCAL_DIR is the app's working directory, so files in the
# dataset could overwrite app files with the same name — confirm acceptable.
snapshot_download(
    repo_id=SONG_SOURCE,
    repo_type="dataset",
    local_dir=LOCAL_DIR,
    token=TOKEN,
)

# Background uploader that periodically commits the local log folder to the
# DS_ID dataset under "data/" (huggingface_hub documents `every` in minutes).
# Writers take `scheduler.lock` while appending so uploads see whole lines.
scheduler = CommitScheduler(
    repo_id=DS_ID,
    repo_type="dataset",
    folder_path=os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    every=10,
    token=TOKEN,
)
|
|
|
# Song metadata table loaded from the downloaded dataset.
df = pd.read_csv(os.path.join(LOCAL_DIR, 'data.csv'))

# Audio path for every row, in row order: <LOCAL_DIR>/songs/<filename>.mp3.
# Built with element-wise string concatenation on the `filename` column.
filenames = list(
    os.path.join(LOCAL_DIR, 'songs')
    + '/'
    + df.filename
    + '.mp3'
)
|
|
|
# Master list of every row label in `df`; the pool is refilled from it.
main_indices = list(df.index)
# Working pool of labels not yet served in the current pass (independent copy).
indices = main_indices.copy()
|
|
|
def init_indices():
    """Refill the working pool ``indices`` from the master list ``main_indices``.

    A fresh copy is taken. The previous version rebound ``indices`` to the
    very same list object, so in-place operations on the pool (such as the
    ``np.random.shuffle`` in ``pick_and_remove_one``) also mutated
    ``main_indices``. Any in-place removal would have emptied the master list
    for good.
    """
    global indices
    indices = list(main_indices)
|
|
|
|
|
def pick_and_remove_one():
    """Draw one index at random from the pool ``indices`` and drop it from the pool.

    When the pool is empty it is refilled with ``init_indices()`` first, so
    every index is served once before any index repeats.

    Returns:
        The selected index (also printed for debugging).
    """
    global indices

    pool_is_empty = len(indices) == 0
    if pool_is_empty:
        init_indices()

    # Shuffle in place, then take the head and keep the tail as the new pool.
    np.random.shuffle(indices)
    picked = indices[0]
    indices = indices[1:]

    print("Indices : ", picked)
    return picked
|
|
|
|
|
def _append_log_record(path, state, vote_type, request):
    """Append one vote/flag record to ``path`` as a single JSON line.

    Args:
        path: Log file to append to (created if it does not exist).
        state: Round state; its ``.dict()`` result is stored under ``"state"``.
        vote_type: Label stored under ``"type"`` (e.g. "real" or "fake").
        request: Incoming Gradio request, used only to resolve the client IP.

    The record is built and serialized before ``scheduler.lock`` is taken.
    The lock is shared with the CommitScheduler's upload, so it is now held
    only for the file append itself. A serialization error can no longer
    leave a freshly created, empty log file behind.
    """
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    line = json.dumps(record) + "\n"
    with scheduler.lock:
        with open(path, "a", encoding="utf-8") as fout:
            fout.write(line)


def vote_last_response(state, vote_type, request: gr.Request):
    """Append a vote record for ``state`` to LOG_FILENAME."""
    _append_log_record(LOG_FILENAME, state, vote_type, request)


def flag_last_response(state, vote_type, request: gr.Request):
    """Append a flag record for ``state`` to FLAG_FILENAME."""
    _append_log_record(FLAG_FILENAME, state, vote_type, request)
|
|
|
|
|
class AudioStateIG:
    """State for one listening round: a unique id plus the dataset row being played."""

    def __init__(self, row):
        # Random hex id so records from different rounds can be told apart.
        self.conv_id = uuid4().hex
        # Dataset row; expected to expose ``label`` and ``filename`` attributes
        # (e.g. a pandas Series from ``df.loc[idx]``) — TODO confirm for other callers.
        self.row = row

    @staticmethod
    def _to_builtin(value):
        """Convert a NumPy scalar (e.g. ``numpy.int64``) to the equivalent Python builtin."""
        return value.item() if isinstance(value, np.generic) else value

    def dict(self):
        """Return a JSON-serializable summary of this round.

        Values read from a pandas row may be NumPy scalars (for example a
        numeric ``label`` column yields ``numpy.int64``). ``json.dumps``
        rejects those, so they are converted to builtins here. Plain ``str``
        values pass through unchanged.
        """
        return {
            "conv_id": self.conv_id,
            "label": self._to_builtin(self.row.label),
            "filename": self._to_builtin(self.row.filename),
        }
|
|
|
def get_ip(request: gr.Request):
    """Return a best-effort client IP for ``request``, or None when there is no request.

    A present, non-empty ``cf-connecting-ip`` header takes precedence.
    Otherwise the peer address in ``request.client.host`` is used.
    """
    if not request:
        return None

    headers = request.headers
    if "cf-connecting-ip" in headers:
        # An empty header value still falls back to the socket peer address.
        return headers["cf-connecting-ip"] or request.client.host
    return request.client.host
|
|
|
|
|
def get_song(idx, df=df, filenames=filenames):
    """Return ``(AudioStateIG for row idx, audio path for idx)``.

    ``idx`` is used both as a ``df.loc`` label and as a position into
    ``filenames``. These agree for the default RangeIndex that
    ``pd.read_csv`` produces.
    NOTE(review): they would diverge if ``df`` were ever re-indexed.
    """
    song_row = df.loc[idx]
    song_path = filenames[idx]
    return AudioStateIG(song_row), song_path
|
|
|
def generate_songs(state):
    """Start a new round with the next randomly drawn song.

    The incoming ``state`` is not read; a fresh AudioStateIG replaces it.

    Returns:
        ``(new_state, audio_path, "Vote to Reveal Label")``, where the last
        element is placeholder text for the label display.
    """
    next_state, audio_path = get_song(pick_and_remove_one())
    return next_state, audio_path, "Vote to Reveal Label"
|
|
|
def _record_vote_and_reveal(state, vote_type, request):
    """Log ``vote_type`` for ``state`` and build the updates that close the round.

    Returns:
        A 4-tuple: three ``disable_btn`` updates followed by a visible
        Markdown heading showing the song's true label.
    """
    vote_last_response(state, vote_type, request)
    reveal = gr.Markdown(f"### {state.row.label}", visible=True)
    return (disable_btn,) * 3 + (reveal,)


def fake_last_response(state, request: gr.Request):
    """Handle a "fake" vote: log it, disable three button outputs, reveal the label."""
    return _record_vote_and_reveal(state, "fake", request)


def real_last_response(state, request: gr.Request):
    """Handle a "real" vote: log it, disable three button outputs, reveal the label."""
    return _record_vote_and_reveal(state, "real", request)
|
|
|
|