delete useless message
- app.py +3 -1
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md +0 -65
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml +1 -1
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py +9 -52
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py +0 -12
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py +0 -4
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py +2 -35
app.py
CHANGED
@@ -66,12 +66,14 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
     # format lyric
     lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
     paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
+    if len(paragraphs) < 1:
+        return None, json.dumps("Lyrics can not be left blank")
     paragraphs_norm = []
     for para in paragraphs:
         lines = para.splitlines()
         struct_tag = lines[0].strip().lower()
         if struct_tag not in STRUCTS:
-            return None, json.dumps(f"
+            return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
         if struct_tag in ['[verse]', '[chorus]', '[bridge]']:
             if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
                 return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
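For reference, the added checks can be exercised in isolation. This is a minimal sketch, not the app itself: STRUCTS below is a hypothetical stand-in for the list defined in app.py, and check_lyric only mirrors the validation branches shown in this hunk.

```
import json

# Hypothetical stand-in for the STRUCTS list defined elsewhere in app.py.
STRUCTS = ['[verse]', '[chorus]', '[bridge]', '[intro-short]', '[inst-short]', '[outro-short]']

def check_lyric(lyric):
    """Return a JSON-encoded error message, or None if the lyric passes the new checks."""
    lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
    paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
    if len(paragraphs) < 1:
        return json.dumps("Lyrics can not be left blank")
    for para in paragraphs:
        lines = para.splitlines()
        struct_tag = lines[0].strip().lower()
        if struct_tag not in STRUCTS:
            return json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
        if struct_tag in ['[verse]', '[chorus]', '[bridge]']:
            if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
                return json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
    return None

print(check_lyric(""))                                   # empty lyric -> error message
print(check_lyric("[verse]\nfirst line\nsecond line"))   # valid segment -> None
```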
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md
DELETED
@@ -1,65 +0,0 @@
-# Our MERT & BEST-RQ
-Our implementation of the MERT model. Files modified:
-- mert_fairseq/models/mert/mert_model.py
-- mert_fairseq/data/mert_dataset.py
-- run_training_mulNodes_wotorchdist_womodelparsize.sh
-
-# Prepare
-
-The MERT training is implemented with [fairseq](https://github.com/pytorch/fairseq). You need to clone the fairseq repo inside our repo at ./src/fairseq and add the MERT implementation code as a fairseq example project.
-
-You can do that by following these steps:
-```
-mkdir -p ./src/fairseq
-cd ./src
-git clone https://github.com/pytorch/fairseq
-```
-
-
-# Docker
-```
-mirrors.tencent.com/cloudezhou/mert:v3
-```
-
-# Start
-
-### 1-node training
-
-```
-bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node
-```
-
-
-### 1-node training (BEST-RQ)
-
-```
-bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_95M_bestrq
-```
-
-### 4-node training
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 2 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 3 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-```
-
-
-### 4-node training (BEST-RQ)
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MERT_RVQ-VAE_CQT_95M_bestrq_multinodes BEST_RQ $CHIEF_IP
-```
-
-### 4-node training (MusicFM)
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MusicFM_95M_multinodes MUSICFM $CHIEF_IP
-```
-
-### 4-node training (EAT)
-```
-bash run_training_eat.sh $INDEX dummy EAT_pretraining_music_multinodes EAT $CHIEF_IP
-```
-
-You can set the parameters in [mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml](mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml)
-
-Our latest checkpoint is at [data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt](data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt)
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml
CHANGED
@@ -18,7 +18,7 @@ checkpoint:
 
 task:
   _name: mae_image_pretraining
-  data:
+  data: unbalanced_train
   rebuild_batches: true
   key: source
   precompute_mask_config: {}
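The change fills in the previously empty data key for the mae_image_pretraining task. A quick way to sanity-check the merged value is to load the file with OmegaConf (already imported elsewhere in this repo); this is only a sketch, and the relative path assumes you run it from the repository root.

```
from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml"
)
print(cfg.task["_name"])  # mae_image_pretraining
print(cfg.task.data)      # expected to be "unbalanced_train" after this change
```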
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py
CHANGED
@@ -274,11 +274,6 @@ class MERTDataset(FairseqDataset):
         dataset_len:int = 128*3000,
         clip_secs = 5,
     ):
-        # self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
-        #     manifest_path, max_keep_sample_size, min_keep_sample_size
-        # )
-
-        # manifest_path = '/apdcephfs_cq2/share_1297902/speech_user/erichtchen/shixisheng/zhouyz/MERT/music_data/all_v4/train.json'
         self.sample_rate = sample_rate
         self.shuffle = shuffle
         self.random_crop = random_crop
@@ -308,15 +303,8 @@ class MERTDataset(FairseqDataset):
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
         else:
            self.label_paths = label_paths
-
-            #     load_label_offset(p, inds, tot) for p in label_paths
-            # ]
+
         assert label_processors is None or len(label_processors) == self.num_labels
-        # logger.info('skip verify labels and audio lengths')
-        # for label_path, label_rate in zip(label_paths, self.label_rates):
-        #     verify_label_lengths(
-        #         self.sizes, sample_rate, label_path, label_rate, inds, tot
-        #     )
 
         self.max_sample_size = (
             max_sample_size if max_sample_size is not None else sys.maxsize
@@ -330,9 +318,6 @@ class MERTDataset(FairseqDataset):
 
         self.augmentation_effects = augmentation_effects
         self.augmentation_probs = augmentation_probs
-        # if len(self.augmentation_effects) > 0:
-        #     self.augmentor_init()
-        #     self.apply_augmentation = self.augmentation_factry(sample_rate)
 
         self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
         self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
@@ -397,10 +382,7 @@ class MERTDataset(FairseqDataset):
 
         return augmented_audio
     def get_audio_by_slice(self,index):
-
-        # wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index])
         wav_path = self.datas[index]['path']
-        # print(wav_path)
         audio_info = torchaudio.info(wav_path)
         origin_sample_rate = audio_info.sample_rate
         origin_duration = audio_info.num_frames / origin_sample_rate
@@ -408,32 +390,14 @@ class MERTDataset(FairseqDataset):
         wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
         wav = wav.float()
 
-        # _path, slice_ptr = parse_path(wav_path)  # this probably needs to change as well
-        # original way
-        # if len(slice_ptr) == 0:
-        #     wav, cur_sample_rate = sf.read(_path)
-        # else:
-        #     assert _path.endswith(".zip")
-        #     data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
-        #     f = io.BytesIO(data)
-        #     wav, cur_sample_rate = sf.read(f)
-        #     wav = torch.from_numpy(wav).float()
-        # print(wav.shape)
         wav = wav.permute(1,0)
         wav = self.postprocess(wav, self.sample_rate)  # downmix to mono, check the sample rate, normalize
-        # print(wav.shape)
-
-        # wav = wav.squeeze(0)
         return wav
+
     def get_audio(self, index):
         import soundfile as sf
-
-
-        wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index])
-        # print(wav_path)
-        # self.reader()
-        _path, slice_ptr = parse_path(wav_path)  # this probably needs to change as well
-        # original way
+        wav_path = self.audio_names[index]
+        _path, slice_ptr = parse_path(wav_path)
         if len(slice_ptr) == 0:
             wav, cur_sample_rate = sf.read(_path)
         else:
@@ -448,8 +412,6 @@ class MERTDataset(FairseqDataset):
         return wav
 
     def get_label(self, index, label_idx):
-        # label_idx selects which dictionary to use (8 dictionaries by default)
-
        if self.store_labels and (not self.npmemmap):
            label = self.label_list[label_idx][index]
        elif self.store_labels and self.npmemmap:
@@ -570,11 +532,6 @@ class MERTDataset(FairseqDataset):
         cqt_labels = self.encoder_cqt_model(collated_audios.float(), forward_type='compute_cqt')
 
         for i, _ in enumerate(audios):
-            # compute cqt labels in advance
-            # cqt_labels
-
-            # yizhilll: apply audio augmentation effects here
-            # the audio should be as the type torch.Tensor, in the shape [1, length] TODO?
             if len(self.augmentation_effects) > 0:
                 with torch.no_grad():
                     for effect, prob in zip(self.augmentation_effects, self.augmentation_probs):
@@ -597,12 +554,12 @@ class MERTDataset(FairseqDataset):
 
     def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
         assert label_rate > 0
-        s2f = label_rate / self.sample_rate
-        frm_starts = [int(round(s * s2f)) for s in audio_starts]
-        frm_size = int(round(audio_size * s2f))
+        s2f = label_rate / self.sample_rate
+        frm_starts = [int(round(s * s2f)) for s in audio_starts]
+        frm_size = int(round(audio_size * s2f))
         if not self.pad_audio:
-            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
-            frm_size = min(frm_size, *rem_size)
+            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+            frm_size = min(frm_size, *rem_size)
         targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
         logger.debug(f"audio_starts={audio_starts}")
         logger.debug(f"frame_starts={frm_starts}")
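The re-indented lines in collater_frm_label are a plain sample-to-frame conversion. A worked sketch with made-up numbers (the sample rate, label rate, and crop offsets below are illustrative, not values taken from this repo):

```
# Illustrative numbers: map audio-sample offsets to label-frame offsets.
sample_rate = 24000            # audio samples per second (assumed)
label_rate = 75                # label frames per second (assumed)
audio_starts = [0, 12000]      # crop start positions, in samples
audio_size = 120000            # crop length, in samples (5 s)

s2f = label_rate / sample_rate                              # 0.003125 frames per sample
frm_starts = [int(round(s * s2f)) for s in audio_starts]    # [0, 38]
frm_size = int(round(audio_size * s2f))                     # 375 frames

# The pad_audio=False branch also clips to the shortest remaining label sequence.
targets = [list(range(400)), list(range(380))]              # fake label streams
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)                         # min(375, 400, 342) = 342
print(frm_starts, frm_size)
```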
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py
CHANGED
@@ -247,15 +247,3 @@ class ChromaSpectrogram(torch.nn.Module):
         chroma_spectrogram[chroma_spectrogram < 0] = 0.0
         chroma_spectrogram = torch.nn.functional.normalize(chroma_spectrogram, p=2, dim=-2)
         return chroma_spectrogram
-
-if __name__ == '__main__':
-    import numpy as np
-    import librosa
-    audio_path = 'speech_data/pretrain/music_42/226849998.flac'
-    sr = 24000
-    freq = 75
-    hop = int(sr // freq)
-    y, _sr = librosa.load(audio_path, duration=5, sr=sr)
-
-    chroma_extractor = ChromaSpectrogram(sample_rate=sr, hop_length=hop, n_fft=2048, use_cqt=True)
-    chroma_tr = chroma_extractor(torch.from_numpy(y)).numpy()
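The deleted block was only an inline smoke test. An equivalent standalone sketch for anyone who wants to try the module after this change; the import path is inferred from the file location, and the audio file is a placeholder:

```
import torch
import librosa

# Import path inferred from mert_fairseq/models/mert/chroma_torch.py; adjust to your setup.
from mert_fairseq.models.mert.chroma_torch import ChromaSpectrogram

sr = 24000
freq = 75
hop = int(sr // freq)

y, _sr = librosa.load("some_clip.flac", duration=5, sr=sr)   # any short audio file
chroma_extractor = ChromaSpectrogram(sample_rate=sr, hop_length=hop, n_fft=2048, use_cqt=True)
chroma = chroma_extractor(torch.from_numpy(y))
print(chroma.shape)   # chroma bins over time frames
```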
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py
CHANGED
@@ -1293,10 +1293,8 @@ class MERTModel(BaseFairseqModel):
         feat_tsz = features.size(2)
         targ_tsz = min([t.size(1) for t in target_list])
         if self.feat2tar_ratio * feat_tsz > targ_tsz:
-            # @yizhilll: if feature * 2 > 3000, then crop the features
             feat_tsz = int(targ_tsz / self.feat2tar_ratio)
             features = features[..., :feat_tsz]
-        # @yizhilll: select only the first pseudo label if there are multiple labels
         target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
         target_list = [t[:, target_inds.long()] for t in target_list]
         return features, target_list
@@ -1507,8 +1505,6 @@ class MERTModel(BaseFairseqModel):
 
         if not self.skip_masked:
             masked_indices = torch.logical_and(~padding_mask, mask_indices)
-
-            # @yizhilll: TODO merge the codes here
             if self.random_codebook <= 0:
                 proj_x_m = self.final_proj(x[masked_indices])  # project the features into a lower-dimensional space
                 if self.untie_final_proj:
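The surviving logic crops the feature sequence so that feature frames and pseudo-label frames stay aligned through feat2tar_ratio. A small numeric sketch of that alignment (the shapes and the ratio are made up for illustration):

```
import torch

feat2tar_ratio = 2.0                                 # assumed ratio of label frames per feature frame
features = torch.randn(1, 768, 400)                  # (batch, dim, feat_tsz)
target_list = [torch.randint(0, 1024, (1, 750))]     # one pseudo-label stream, 750 frames

feat_tsz = features.size(2)
targ_tsz = min(t.size(1) for t in target_list)
if feat2tar_ratio * feat_tsz > targ_tsz:             # 800 > 750, so crop the features
    feat_tsz = int(targ_tsz / feat2tar_ratio)        # 375
    features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
print(features.shape, target_list[0].shape)          # torch.Size([1, 768, 375]) torch.Size([1, 375])
```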
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py
CHANGED
@@ -20,8 +20,7 @@ from fairseq.tasks import register_task
 from fairseq.tasks.fairseq_task import FairseqTask
 from omegaconf import MISSING
 
-
-from ..data.mert_dataset import MERTDataset  # this feels problematic, we need a better approach
+from ..data.mert_dataset import MERTDataset
 from ..data.ark_dataset import ArkDataset
 
 logger = logging.getLogger(__name__)
@@ -32,8 +31,6 @@ class LabelEncoder(object):
         self.dictionary = dictionary
 
     def __call__(self, label: str) -> List[str]:
-        # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \
-        # encode_line returns a torch.IntTensor, should be all 1 for vanilla HuBERT
         return self.dictionary.encode_line(
             label,
             append_eos=False,
@@ -45,17 +42,6 @@ class PaddedNumpyLabelEncoder(object):
         pass
 
     def __call__(self, label):
-        # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \
-        # encode_line returns a torch.IntTensor, should be all 1 for vanilla HuBERT
-        # return self.dictionary.encode_line(
-        #     label,
-        #     append_eos=False,
-        #     add_if_not_exist=False,
-        # )
-        # if isinstance(label, np.memmap):
-
-        # assert isinstance(label, np.memmap)
-        # t = torch.IntTensor(np.asarray(label).copy())
         t = torch.IntTensor(np.asarray(label))
         t = t[t>=0] # remove padded -1 values at the end
         return t
@@ -262,9 +248,7 @@ class MERTPretrainingTask(FairseqTask):
         else:
             self.state.add_factory("dictionaries", self.load_dictionaries)
 
-        self.blank_symbol = "<s>"
-
-        # @yizhilll: use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
+        self.blank_symbol = "<s>"
         self.augmentation_effects = eval(self.cfg.augmentation_effects)
         self.augmentation_probs = eval(self.cfg.augmentation_probs)
         if len(self.augmentation_effects) > 0:
@@ -321,14 +305,6 @@ class MERTPretrainingTask(FairseqTask):
             return self.cfg.data
         return self.cfg.label_dir
 
-    # def has_sharded_data(self, split):
-    #     """overwrite this function to let the trainer reload the dataset when the dynamic cropping changes"""
-    #     logger.info(f"check whether to re-load dataset for epoch {epoch} by overwriting task.has_sharded_data()")
-    #     # find the threshold that holds epoch \in [threshold, next_threshold)
-    #     is_reload_dataset = epoch in self.dynamic_crops_epoches
-
-    #     return os.pathsep in getattr(self.cfg, "data", "") or is_reload_dataset
-    # def is_force_load_dataset(self, epoch):
     def is_force_load_dataset(self, epoch, training_restore=False):
         # find the threshold that holds epoch \in [threshold, next_threshold)
         return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
@@ -340,15 +316,6 @@ class MERTPretrainingTask(FairseqTask):
 
     def set_dynamic_crop_max_sample(self, epoch):
         """ force to set the max_sample_size config for the dynamic cropping function"""
-        # pass
-        # @yizhilll: the parameter "epoch" is passed into this function in trainer.py#688,
-        # contained in "**kwargs"
-        # if 'train' in split:
-        #     epoch = kwargs['epoch']
-
-        # find the threshold that holds epoch \in [threshold, next_threshold)
-        # is_reload_dataset = epoch in self.dynamic_crops_epoches # test again
-        # if is_reload_dataset:
         if epoch in self.dynamic_crops_epoches:
             for idx in range(len(self.dynamic_crops_epoches)):
                 if (idx == len(self.dynamic_crops_epoches)-1) or \