|
import gradio as gr |
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
from uuid import uuid4 |
|
import json |
|
import time |
|
import os |
|
from huggingface_hub import CommitScheduler |
|
from functools import partial |
|
import pandas as pd |
|
import numpy as np |
|
from huggingface_hub import snapshot_download |
|
import librosa |
|
import random |
|
|
|
def enable_buttons_side_by_side():
    """Return six component updates, each visible and interactive."""
    updates = [gr.update(visible=True, interactive=True) for _ in range(6)]
    return tuple(updates)
|
|
|
def disable_buttons_side_by_side():
    """Return six non-interactive component updates.

    The first four updates also hide their component; the last two keep it
    visible.
    """
    hidden = [gr.update(visible=False, interactive=False) for _ in range(4)]
    shown = [gr.update(visible=True, interactive=False) for _ in range(2)]
    return tuple(hidden + shown)
|
|
|
|
|
# Votes and flags are appended to per-process files under ./data; each file
# name embeds the ISO timestamp taken when this module is loaded.
os.makedirs('data', exist_ok=True)
LOG_FILENAME = os.path.join('data', f'log_{datetime.now().isoformat()}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{datetime.now().isoformat()}.json')

# Pre-built component updates shared by the event handlers.
enable_btn = gr.update(interactive=True, visible=True)
disable_btn = gr.update(interactive=False)
invisible_btn = gr.update(interactive=False, visible=False)
no_change_btn = gr.update(value="No Change", interactive=True, visible=True)
|
|
|
# Repository ids and the access token are read from the environment.
DS_ID = os.getenv('DS_ID')
TOKEN = os.getenv('TOKEN')
SONG_SOURCE = os.getenv("SONG_SOURCE")
LOCAL_DIR = './'

# Fetch the SONG_SOURCE dataset repository into LOCAL_DIR at import time.
snapshot_download(
    repo_id=SONG_SOURCE,
    repo_type="dataset",
    token=TOKEN,
    local_dir=LOCAL_DIR,
)

# Scheduler that commits the local log folder to the DS_ID dataset repo under
# "data".
# NOTE(review): `every` is expressed in minutes per the huggingface_hub docs -
# confirm that a 10-minute interval is intended.
scheduler = CommitScheduler(
    repo_id=DS_ID,
    repo_type="dataset",
    folder_path=os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    token=TOKEN,
    every=10,
)
|
|
|
# Song metadata table loaded from the downloaded snapshot.
df = pd.read_csv(os.path.join(LOCAL_DIR, 'data.csv'))

# One mp3 path per metadata row, in row order.
filenames = list(os.path.join(LOCAL_DIR, 'songs') + '/' + df.filename + '.mp3')

# `indices` is the shrinking pool of songs not yet served; `main_indices`
# keeps the full set so the pool can be refilled.
indices = list(df.index)
main_indices = indices.copy()
|
|
|
def init_indices():
    """Refill the global ``indices`` pool from ``main_indices``.

    The pool receives its own copy of the master list. Binding ``indices``
    directly to ``main_indices`` would make both names share one list, and
    the in-place shuffle performed on ``indices`` would then reorder
    ``main_indices`` as well.
    """
    global indices
    indices = list(main_indices)
|
|
|
|
|
def pick_and_remove_one():
    """Return one randomly chosen index and remove it from the global pool.

    When the pool is empty it is first refilled through ``init_indices``.
    """
    global indices
    if not indices:
        init_indices()

    # Shuffle in place, then take the head and keep the tail as the new pool.
    np.random.shuffle(indices)
    chosen, indices = indices[0], indices[1:]
    return chosen
|
|
|
|
|
def vote_last_response(state, vote_type, request: gr.Request):
    """Append one vote record as a JSON line to ``LOG_FILENAME``.

    Args:
        state: Object whose ``dict()`` method returns the state to log.
        vote_type: Value stored under the ``"type"`` key.
        request: Gradio request passed to ``get_ip`` to obtain the client IP.
    """
    # The scheduler lock is held for the whole open/append/close sequence.
    with scheduler.lock, open(LOG_FILENAME, "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        log_file.write(json.dumps(record) + "\n")
|
|
|
def flag_last_response(state, vote_type, request: gr.Request):
    """Append one flag record as a JSON line to ``FLAG_FILENAME``.

    Args:
        state: Object whose ``dict()`` method returns the state to log.
        vote_type: Value stored under the ``"type"`` key.
        request: Gradio request passed to ``get_ip`` to obtain the client IP.
    """
    # The scheduler lock is held for the whole open/append/close sequence.
    with scheduler.lock, open(FLAG_FILENAME, "a") as flag_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        flag_file.write(json.dumps(record) + "\n")
|
|
|
|
|
class AudioStateIG:
    """State for one served song: its metadata row plus a fresh id."""

    def __init__(self, row):
        """Store ``row`` and generate a random hexadecimal ``conv_id``."""
        self.row = row
        self.conv_id = uuid4().hex
        # None means "report row.duration"; see update_duration().
        self.new_duration = None

    def update_duration(self, duration):
        """Set the duration that ``dict()`` reports instead of ``row.duration``."""
        self.new_duration = duration

    def dict(self):
        """Return the state as a plain dict.

        ``"duration"`` is the overridden duration when one has been set,
        otherwise the duration stored on the row. ``"song_id"`` is the row
        id converted to a string.
        """
        row = self.row
        if self.new_duration is None:
            duration = row.duration
        else:
            duration = self.new_duration

        return {
            "conv_id": self.conv_id,
            "label": row.label,
            "filename": row.filename,
            "duration": duration,
            "song_id": str(row.id),
            "source": row.source,
            "algorithm": row.algorithm,
        }
|
|
|
def get_ip(request: gr.Request):
    """Return the client IP for ``request``, or ``None`` when it is falsy.

    A non-empty ``cf-connecting-ip`` header takes precedence; otherwise the
    address in ``request.client.host`` is returned.
    """
    if not request:
        return None

    headers = request.headers
    if "cf-connecting-ip" in headers and headers["cf-connecting-ip"]:
        return headers["cf-connecting-ip"]
    return request.client.host
|
|
|
|
|
def get_song(idx, df=df, filenames=filenames):
    """Return ``(state, audio_path)`` for the song at ``idx``.

    Args:
        idx: Row label looked up with ``df.loc`` and position used to index
            ``filenames``.
        df: Metadata table; defaults to the module-level ``df``.
        filenames: Audio paths; defaults to the module-level ``filenames``.

    Returns:
        A new ``AudioStateIG`` wrapping the row, and the matching audio path.
    """
    song_row, song_path = df.loc[idx], filenames[idx]
    return AudioStateIG(song_row), song_path
|
|
|
def random_cut_length(audio_data, max_length, sample_rate):
    """Cut a randomly placed excerpt of 125, 55 or 25 seconds.

    The excerpt length is drawn from those candidates that are strictly
    shorter than ``max_length``. When none qualifies, the input is returned
    unchanged.

    Args:
        audio_data: Sliceable sequence of samples.
        max_length: Length of the audio in seconds (may be fractional).
        sample_rate: Samples per second, used to turn seconds into offsets.

    Returns:
        ``(excerpt, excerpt_length_in_seconds)``, or
        ``(audio_data, max_length)`` when no cut is made.
    """
    options = [length for length in (125, 55, 25) if max_length > length]
    if not options:
        return audio_data, max_length

    length_picked = random.choice(options)

    # np.random.randint truncates its upper bound to an int and raises
    # ValueError on an empty range. A fractional surplus below one second
    # (e.g. a duration of 25.4) leaves no room to move, so start at 0.
    slack = int(max_length - length_picked)
    start_point = np.random.randint(0, slack) if slack > 0 else 0
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked
|
|
|
def constant_cut_length(audio_data, max_length, sample_rate, length_picked=25):
    """Cut a randomly placed excerpt of ``length_picked`` seconds.

    Args:
        audio_data: Sliceable sequence of samples.
        max_length: Length of the audio in seconds (may be fractional).
        sample_rate: Samples per second, used to turn seconds into offsets.
        length_picked: Excerpt length in seconds.

    Returns:
        ``(excerpt, length_picked)``, or ``(audio_data, max_length)`` when the
        audio is not longer than ``length_picked``.
    """
    if max_length <= length_picked:
        return audio_data, max_length

    # np.random.randint truncates its upper bound to an int and raises
    # ValueError on an empty range. A fractional surplus below one second
    # (e.g. a duration of 25.4) leaves no room to move, so start at 0.
    slack = int(max_length - length_picked)
    start_point = np.random.randint(0, slack) if slack > 0 else 0
    end_point = start_point + length_picked
    audio_data_cut = audio_data[start_point * sample_rate:end_point * sample_rate]
    return audio_data_cut, length_picked
|
|
|
def generate_songs(state, song_cut_function=constant_cut_length):
    """Pick the next song and return what the UI needs to present it.

    The incoming ``state`` is not read; a new state is created for the
    picked song.

    Args:
        state: Previous state (replaced).
        song_cut_function: Callable ``(audio_data, duration, sample_rate)``
            returning ``(excerpt, new_length)``, or ``None`` to serve the
            file path without cutting.

    Returns:
        ``(state, audio, "Vote to Reveal Label")``, where ``audio`` is a
        ``(sample_rate, samples)`` pair when a cut function is given and the
        file path otherwise.
    """
    idx = pick_and_remove_one()
    state, audio = get_song(idx)
    reveal_text = "Vote to Reveal Label"

    if song_cut_function is None:
        return state, audio, reveal_text

    # sr=None asks librosa to keep the file's own sampling rate.
    audio_data, sample_rate = librosa.load(audio, sr=None)
    audio_cut, new_length = song_cut_function(audio_data, state.row.duration, sample_rate)
    state.update_duration(new_length)
    return state, (sample_rate, audio_cut), reveal_text
|
|
|
def fake_last_response(state, request: gr.Request):
    """Log a "fake" vote and reveal the song's true label.

    Returns:
        Two ``disable_btn`` updates followed by a visible Markdown component
        showing the label and, for songs not labelled ``'real'``, the
        algorithm from the row.
    """
    vote_last_response(state, "fake", request)

    song = state.row
    markdown_text = f"### {song.label}"
    if song.label != 'real':
        markdown_text += f"\nModel : {song.algorithm}"

    reveal = gr.Markdown(markdown_text, visible=True)
    return disable_btn, disable_btn, reveal
|
|
|
def real_last_response(state, request: gr.Request):
    """Log a "real" vote and reveal the song's true label.

    Returns:
        Two ``disable_btn`` updates followed by a visible Markdown component
        showing the label and, for songs not labelled ``'real'``, the
        algorithm from the row.
    """
    vote_last_response(state, "real", request)

    song = state.row
    markdown_text = f"### {song.label}"
    if song.label != 'real':
        markdown_text += f"\nModel : {song.algorithm}"

    reveal = gr.Markdown(markdown_text, visible=True)
    return disable_btn, disable_btn, reveal
|
|
|
|