# Spaces: Build error
import gradio as gr | |
import json | |
from datetime import datetime | |
from pathlib import Path | |
from uuid import uuid4 | |
import json | |
import time | |
import os | |
from huggingface_hub import CommitScheduler | |
from functools import partial | |
import pandas as pd | |
import numpy as np | |
from huggingface_hub import snapshot_download | |
import librosa | |
import random | |
def enable_buttons_side_by_side():
    """Return six ``gr.update`` payloads that show and enable every button."""
    payloads = [gr.update(visible=True, interactive=True) for _ in range(6)]
    return tuple(payloads)
def disable_buttons_side_by_side():
    """Return six ``gr.update`` payloads that make every button non-interactive.

    The first four payloads also hide their button; the last two (positions
    4 and 5) keep theirs visible.
    """
    hidden = [gr.update(visible=False, interactive=False) for _ in range(4)]
    still_shown = [gr.update(visible=True, interactive=False) for _ in range(2)]
    return tuple(hidden + still_shown)
# Local directory that receives the JSON-lines logs (later pushed by the
# CommitScheduler, whose folder_path is derived from LOG_FILENAME).
os.makedirs('data', exist_ok=True)

# Take the timestamp once, so the vote log and the flag log written by this
# process share the same suffix and can be matched up afterwards.
# NOTE(review): isoformat() contains ':' characters, which are not valid in
# Windows filenames; fine on Linux-hosted Spaces.
_SESSION_TIMESTAMP = datetime.now().isoformat()
LOG_FILENAME = os.path.join('data', f'log_{_SESSION_TIMESTAMP}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{_SESSION_TIMESTAMP}.json')
# Reusable gr.update payloads for toggling a single button's state.
enable_btn = gr.update(interactive=True, visible=True)      # shown and clickable
disable_btn = gr.update(interactive=False)                   # greyed out, visibility untouched
invisible_btn = gr.update(interactive=False, visible=False)  # hidden and inert
no_change_btn = gr.update(                                   # shown, clickable, relabelled
    value="No Change",
    interactive=True,
    visible=True,
)
# Deployment configuration, read from the Space's environment.
DS_ID = os.getenv('DS_ID')              # dataset repo that receives the logs
TOKEN = os.getenv('TOKEN')              # Hub token passed to both Hub calls below
SONG_SOURCE = os.getenv("SONG_SOURCE")  # dataset repo holding the songs
LOCAL_DIR = './'

# Fail fast with an actionable message. With DS_ID or SONG_SOURCE unset, the
# Hub calls below would receive repo_id=None and fail with a far less obvious
# error. TOKEN stays optional: huggingface_hub falls back to its own
# credential lookup when token is None.
_missing_env = [
    name
    for name, value in (("DS_ID", DS_ID), ("SONG_SOURCE", SONG_SOURCE))
    if not value
]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env)
    )

# Mirror the song dataset into LOCAL_DIR (presumably this provides the
# data.csv file and songs/ folder read further down — verify against the repo).
snapshot_download(repo_id=SONG_SOURCE, repo_type="dataset", token=TOKEN, local_dir=LOCAL_DIR)

# Periodically commit the local log directory to DS_ID under "data/".
# `every` is the push interval (in minutes, per the huggingface_hub docs).
scheduler = CommitScheduler(
    repo_id=DS_ID,
    repo_type="dataset",
    folder_path=os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    token=TOKEN,
    every=10,
)
# Song metadata; each row's `filename` names an mp3 under LOCAL_DIR/songs.
df = pd.read_csv(os.path.join(LOCAL_DIR, 'data.csv'))

# Audio path per row, aligned positionally with df. Built with os.path.join
# end-to-end (rather than '/' concatenation over a Series, which also raises
# TypeError if the column was parsed as non-string values).
# NOTE(review): get_song indexes this list positionally but df by label via
# df.loc, so the two only agree while df keeps read_csv's default RangeIndex.
filenames = [os.path.join(LOCAL_DIR, 'songs', f'{name}.mp3') for name in df.filename]

# Draw pool consumed by pick_and_remove_one(), plus the full list it is
# refilled from once exhausted.
indices = list(df.index)
main_indices = indices.copy()
def init_indices():
    """Refill the draw pool ``indices`` from the full index list ``main_indices``.

    A copy is stored. pick_and_remove_one() shuffles ``indices`` in place, so
    aliasing ``main_indices`` would let that shuffle reorder the master list.
    """
    global indices
    indices = list(main_indices)
def pick_and_remove_one():
    """Draw one random entry from the module-level ``indices`` pool and remove it.

    When the pool is empty it is refilled via init_indices() first. The
    remaining pool is reshuffled with np.random.shuffle on every call, and
    its head is returned.
    """
    global indices
    if not indices:
        init_indices()
    np.random.shuffle(indices)
    chosen, indices = indices[0], indices[1:]
    return chosen
def _append_record(filename, state, vote_type, request):
    """Append one JSON line describing a user action on *state* to *filename*.

    The record is built before taking the lock, so the lock is held only for
    the file write. The write happens under ``scheduler.lock``, the lock
    huggingface_hub's CommitScheduler exposes for coordinating writes with
    its background commits.
    """
    record = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    line = json.dumps(record) + "\n"
    with scheduler.lock:
        with open(filename, "a", encoding="utf-8") as fout:
            fout.write(line)


def vote_last_response(state, vote_type, request: gr.Request):
    """Log a vote of *vote_type* on *state* to the session's LOG_FILENAME."""
    _append_record(LOG_FILENAME, state, vote_type, request)


def flag_last_response(state, vote_type, request: gr.Request):
    """Log a flag of *vote_type* on *state* to the session's FLAG_FILENAME."""
    _append_record(FLAG_FILENAME, state, vote_type, request)
class AudioStateIG:
    """State of one listening round: a dataset row plus round bookkeeping.

    Holds a fresh hex conversation id, the source ``row`` (accessed by
    attribute: label, filename, duration, id, source, algorithm), and an
    optional duration override recorded after the audio is trimmed.
    """

    def __init__(self, row):
        self.conv_id = uuid4().hex
        self.row = row
        self.new_duration = None  # set by update_duration() when the clip is cut

    def dict(self):
        """Return a flat summary of the round for the JSON logs.

        ``duration`` is the override when one was recorded, otherwise the
        row's own duration; ``song_id`` is the row id as a string.
        """
        row = self.row
        summary = {
            "conv_id": self.conv_id,
            "label": row.label,
            "filename": row.filename,
        }
        if self.new_duration is None:
            summary["duration"] = row.duration
        else:
            summary["duration"] = self.new_duration
        summary["song_id"] = str(row.id)
        summary["source"] = row.source
        summary["algorithm"] = row.algorithm
        return summary

    def update_duration(self, duration):
        """Record the length of the clip actually played."""
        self.new_duration = duration
def get_ip(request: gr.Request):
    """Return a best-effort client IP for logging, or None.

    Prefers the ``cf-connecting-ip`` header (presumably set by a Cloudflare
    proxy in front of the app) and otherwise uses the connection's client
    host.

    Args:
        request: The Gradio request; a falsy value means no request is known.

    Returns:
        The IP string, or None when it cannot be determined.
    """
    if not request:
        return None
    # Starlette documents request.client as possibly None; dereferencing it
    # unconditionally would crash the vote/flag handlers that call this.
    client = request.client
    fallback = client.host if client is not None else None
    headers = request.headers
    if headers and "cf-connecting-ip" in headers:
        # An empty header value still falls back to the client host.
        return headers["cf-connecting-ip"] or fallback
    return fallback
def get_song(idx, df=df, filenames=filenames):
    """Build a fresh round state and the audio path for dataset row *idx*.

    Args:
        idx: Row identifier, used as a label for ``df.loc`` and as a position
            into ``filenames``. NOTE(review): these agree only while df keeps
            read_csv's default RangeIndex.
        df: Song metadata table (bound to the module-level df at definition).
        filenames: Audio paths aligned with df's rows.

    Returns:
        (AudioStateIG for the row, audio file path)
    """
    row = df.loc[idx]
    audio_path = filenames[idx]
    return AudioStateIG(row), audio_path
def _cut_window(audio_data, max_length, sample_rate, length_picked):
    """Return a random *length_picked*-second slice of *audio_data*.

    The start is a whole second drawn from ``[0, max_length - length_picked)``.
    For fractional durations that bound is truncated, which matches
    np.random.randint's own truncation, and clamped to at least 1. Without
    the clamp, ``0 < max_length - length_picked < 1`` made randint raise
    ValueError.
    """
    upper = max(1, int(max_length - length_picked))
    start_point = np.random.randint(0, upper)
    end_point = start_point + length_picked
    # Offsets are in seconds; convert them to sample indices.
    return audio_data[start_point * sample_rate : end_point * sample_rate]


def random_cut_length(audio_data, max_length, sample_rate, lengths=(125, 55, 25)):
    """Trim audio to a randomly chosen clip length at a random offset.

    Args:
        audio_data: 1-D sample sequence supporting slicing.
        max_length: Duration of ``audio_data`` in seconds.
        sample_rate: Samples per second.
        lengths: Candidate clip lengths in seconds. Only those strictly
            shorter than ``max_length`` are eligible; one is picked with
            ``random.choice``.

    Returns:
        (clip, clip_length). When no candidate fits, this is the untouched
        ``audio_data`` and ``max_length``.
    """
    options = [length for length in lengths if max_length > length]
    if not options:
        return audio_data, max_length
    length_picked = random.choice(options)
    return _cut_window(audio_data, max_length, sample_rate, length_picked), length_picked


def constant_cut_length(audio_data, max_length, sample_rate, length_picked=25):
    """Trim audio to a fixed *length_picked*-second clip at a random offset.

    Returns:
        (clip, length_picked), or the untouched ``audio_data`` and
        ``max_length`` when the audio is not longer than the clip.
    """
    if max_length <= length_picked:
        return audio_data, max_length
    return _cut_window(audio_data, max_length, sample_rate, length_picked), length_picked
def generate_songs(state, song_cut_function=constant_cut_length):
    """Draw the next song and prepare its audio for playback.

    The incoming ``state`` is not used; a fresh state is built for the drawn
    row. When ``song_cut_function`` is given, the file is decoded with librosa
    at its native sample rate and trimmed, and the state's duration is set to
    the resulting clip length.

    Returns:
        (state, audio, prompt). ``audio`` is the file path when no cut
        function is used, otherwise a ``(sample_rate, samples)`` tuple.
        ``prompt`` is the placeholder label text.
    """
    prompt = "Vote to Reveal Label"
    state, audio = get_song(pick_and_remove_one())
    if song_cut_function is None:
        return state, audio, prompt
    samples, sample_rate = librosa.load(audio, sr=None)
    clip, clip_length = song_cut_function(samples, state.row.duration, sample_rate)
    state.update_duration(clip_length)
    return state, (sample_rate, clip), prompt
def _reveal_label(state, vote_type, request):
    """Log a vote, then build the UI updates that reveal the true label.

    Returns:
        Two ``disable_btn`` updates followed by a visible Markdown showing the
        label. For any label other than 'real', the generating algorithm is
        shown on a second line.
    """
    vote_last_response(state, vote_type, request)
    markdown_text = f"### {state.row.label}"
    if state.row.label != 'real':
        markdown_text += f"\nModel : {state.row.algorithm}"
    return (disable_btn,) * 2 + (gr.Markdown(markdown_text, visible=True),)


def fake_last_response(state, request: gr.Request):
    """Handle a "fake" vote: log it and reveal the label."""
    return _reveal_label(state, "fake", request)


def real_last_response(state, request: gr.Request):
    """Handle a "real" vote: log it and reveal the label."""
    return _reveal_label(state, "real", request)