bishmoy committed on
Commit
d4dba14
·
verified ·
1 Parent(s): 3a4d071

create single song utils

Browse files
Files changed (1) hide show
  1. utils.py +145 -0
utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from uuid import uuid4
6
+ import json
7
+ import time
8
+ import os
9
+ from huggingface_hub import CommitScheduler
10
+ from functools import partial
11
+ import pandas as pd
12
+ import numpy as np
13
+ from huggingface_hub import snapshot_download
14
+
15
def enable_buttons_side_by_side():
    """Return six Gradio updates, each marking its target visible and interactive."""
    updates = []
    for _ in range(6):
        updates.append(gr.update(visible=True, interactive=True))
    return tuple(updates)
17
+
18
def disable_buttons_side_by_side():
    """Return six Gradio updates that make every target non-interactive.

    The first four updates also hide their target (positions 0-3); the last
    two (positions 4 and 5) stay visible.
    """
    updates = []
    for position in range(6):
        still_shown = position >= 4
        updates.append(gr.update(visible=still_shown, interactive=False))
    return tuple(updates)
20
+
21
+
22
# Make sure the local output folder exists before any log file is opened.
os.makedirs('data', exist_ok=True)

# Vote and flag records go to separate files; each name embeds the ISO
# timestamp taken when this module is loaded.
LOG_FILENAME = os.path.join('data', 'log_' + datetime.now().isoformat() + '.json')
FLAG_FILENAME = os.path.join('data', 'flagged_' + datetime.now().isoformat() + '.json')
25
+
26
# Reusable Gradio update payloads shared by the event handlers.
enable_btn = gr.update(visible=True, interactive=True)         # show and enable
disable_btn = gr.update(interactive=False)                     # disable, visibility untouched
invisible_btn = gr.update(visible=False, interactive=False)    # hide and disable
no_change_btn = gr.update(                                     # enabled, with the value "No Change"
    value="No Change",
    visible=True,
    interactive=True,
)
30
+
31
# Runtime configuration is read from environment variables; each value is
# None when its variable is unset.
DS_ID = os.environ.get('DS_ID')              # repo id handed to the CommitScheduler
TOKEN = os.environ.get('TOKEN')              # token used for both hub calls
SONG_SOURCE = os.environ.get("SONG_SOURCE")  # dataset repo that is downloaded below
LOCAL_DIR = './'

# Download the SONG_SOURCE dataset repo into LOCAL_DIR at import time.
snapshot_download(
    repo_id=SONG_SOURCE,
    repo_type="dataset",
    token=TOKEN,
    local_dir=LOCAL_DIR,
)
37
+
38
# Background scheduler that pushes the folder containing LOG_FILENAME to the
# DS_ID dataset repo under "data".
# NOTE(review): `every` is expressed in minutes according to the
# huggingface_hub CommitScheduler docs - confirm for the pinned version.
_log_folder = os.path.dirname(LOG_FILENAME)
scheduler = CommitScheduler(
    repo_id=DS_ID,
    repo_type="dataset",
    folder_path=_log_folder,
    path_in_repo="data",
    token=TOKEN,
    every=10,
)
46
+
47
# Table read from data.csv inside LOCAL_DIR; the code below and the state
# class use its `filename` and `label` columns.
df = pd.read_csv(os.path.join(LOCAL_DIR, 'data.csv'))

# One audio path per row, in row order: <LOCAL_DIR>/songs/<filename>.mp3
filenames = list(os.path.join(LOCAL_DIR, 'songs') + '/' + df.filename + '.mp3')

# `indices` is the pool of rows not yet served; `main_indices` keeps the
# full set so the pool can be reset.
indices = list(df.index)
main_indices = list(indices)
52
+
53
def init_indices():
    """Reset the pool of unserved row indices to the full set.

    Rebinds the module-level ``indices`` to a *copy* of ``main_indices``.
    Copying matters because ``indices`` is shuffled in place by the picker;
    sharing the same list object would let that shuffle (or any other
    in-place change) alter the master list.
    """
    global indices
    indices = list(main_indices)
56
+
57
+
58
def pick_and_remove_one():
    """Pick a random index from the pool and remove it from the pool.

    When the module-level ``indices`` pool is empty it is first reset through
    ``init_indices()``. The pool is then shuffled in place, its first element
    is taken as the selection, and ``indices`` is rebound to the remainder.

    Returns:
        The selected index.
    """
    global indices
    if not len(indices) >= 1:
        init_indices()

    np.random.shuffle(indices)
    chosen = indices[0]
    remaining = indices[1:]
    indices = remaining
    print("Indices : ", chosen)
    return chosen
68
+
69
+
70
def vote_last_response(state, vote_type, request: gr.Request):
    """Append one vote record to LOG_FILENAME as a single JSON line.

    The write happens while holding ``scheduler.lock``.

    Args:
        state: object exposing ``dict()``; its result is stored under "state".
        vote_type: value stored under "type".
        request: incoming Gradio request, used to resolve the client IP.
    """
    with scheduler.lock, open(LOG_FILENAME, "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            state=state.dict(),
            ip=get_ip(request),
        )
        line = json.dumps(record)
        log_file.write(line + "\n")
80
+
81
def flag_last_response(state, vote_type, request: gr.Request):
    """Append one flag record to FLAG_FILENAME as a single JSON line.

    The write happens while holding ``scheduler.lock``.

    Args:
        state: object exposing ``dict()``; its result is stored under "state".
        vote_type: value stored under "type".
        request: incoming Gradio request, used to resolve the client IP.
    """
    with scheduler.lock, open(FLAG_FILENAME, "a") as flag_file:
        entry = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            state=state.dict(),
            ip=get_ip(request),
        )
        serialized = json.dumps(entry)
        flag_file.write(serialized + "\n")
91
+
92
+
93
class AudioStateIG:
    """Per-song state: a fresh random id plus the data row of the song."""

    def __init__(self, row):
        # `row` must expose `.label` and `.filename` (read by dict()).
        self.row = row
        # 32-character hex id generated once per state object.
        self.conv_id = uuid4().hex

    def dict(self):
        """Return the serialisable view: conv_id, label and filename."""
        return {
            "conv_id": self.conv_id,
            "label": self.row.label,
            "filename": self.row.filename,
        }
105
+
106
def get_ip(request: gr.Request):
    """Resolve the client IP for a request.

    Returns None when ``request`` is falsy. Otherwise a non-empty
    "cf-connecting-ip" header wins, and ``request.client.host`` is the
    fallback when the header is absent or empty.
    """
    if not request:
        return None
    if "cf-connecting-ip" in request.headers:
        return request.headers["cf-connecting-ip"] or request.client.host
    return request.client.host
115
+
116
+
117
def get_song(idx, df=df, filenames=filenames):
    """Look up one song by index.

    Args:
        idx: used as a label for ``df.loc`` and as a position in ``filenames``.
        df: table holding the song rows (defaults to the module-level table).
        filenames: audio paths (defaults to the module-level list).

    Returns:
        ``(state, audio_path)``: a new AudioStateIG wrapping the row, and the
        matching entry of ``filenames``.
    """
    song_row = df.loc[idx]
    song_path = filenames[idx]
    return AudioStateIG(song_row), song_path
122
+
123
def generate_songs(state):
    """Draw the next song and return ``(state, audio_path, placeholder_text)``.

    The incoming ``state`` argument is not read; a new state is built for the
    drawn song. The third element is the fixed text "Vote to Reveal Label".
    """
    next_state, audio_path = get_song(pick_and_remove_one())
    return (next_state, audio_path, "Vote to Reveal Label")
129
+
130
def fake_last_response(state, request: gr.Request):
    """Record a "fake" vote for ``state`` and build the follow-up UI updates.

    Returns:
        A 4-tuple: ``disable_btn`` three times, then a visible Markdown
        component showing ``state.row.label`` as a level-3 heading.
    """
    vote_last_response(state, "fake", request)
    revealed_label = gr.Markdown(f"### {state.row.label}", visible=True)
    return (disable_btn, disable_btn, disable_btn, revealed_label)
137
+
138
def real_last_response(state, request: gr.Request):
    """Record a "real" vote for ``state`` and build the follow-up UI updates.

    Returns:
        A 4-tuple: ``disable_btn`` three times, then a visible Markdown
        component showing ``state.row.label`` as a level-3 heading.
    """
    vote_last_response(state, "real", request)
    revealed_label = gr.Markdown(f"### {state.row.label}", visible=True)
    return (disable_btn, disable_btn, disable_btn, revealed_label)
145
+