waytan22 committed · Commit f7c5fc0 · 1 Parent(s): 3c8f8cf

delete useless message

app.py CHANGED
@@ -66,12 +66,14 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
     # format lyric
     lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
     paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
+    if len(paragraphs) < 1:
+        return None, json.dumps("Lyrics can not be left blank")
     paragraphs_norm = []
     for para in paragraphs:
         lines = para.splitlines()
         struct_tag = lines[0].strip().lower()
         if struct_tag not in STRUCTS:
-            return None, json.dumps(f"segments should start with a structure tag in {STRUCTS}")
+            return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
         if struct_tag in ['[verse]', '[chorus]', '[bridge]']:
             if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
                 return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
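For reference, the added checks expect the lyric as blank-line-separated segments, each headed by a structure tag. A minimal, self-contained sketch of the same validation (the STRUCTS list and the sample lyric below are illustrative assumptions, not values taken from app.py):

```python
import json

# Assumed tag list for illustration only; app.py defines its own STRUCTS.
STRUCTS = ['[verse]', '[chorus]', '[bridge]', '[intro-short]', '[inst-short]', '[outro-short]']

# A hypothetical lyric: segments separated by blank lines, each starting with a tag.
lyric = "[verse]\nCity lights are fading out\n\n[chorus]\nHold on through the night\n\n[outro-short]"

paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
assert len(paragraphs) >= 1, json.dumps("Lyrics can not be left blank")
for para in paragraphs:
    lines = para.splitlines()
    struct_tag = lines[0].strip().lower()
    assert struct_tag in STRUCTS, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
    if struct_tag in ['[verse]', '[chorus]', '[bridge]']:
        # verse/chorus/bridge segments must carry at least one non-empty lyric line
        assert [line.strip() for line in lines[1:] if line.strip()], \
            json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
print("lyric format OK")
```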
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md DELETED
@@ -1,65 +0,0 @@
-# Our MERT & BEST-RQ
-Our implementation on MERT model. Files modified:
-- mert_fairseq/models/mert/mert_model.py
-- mert_fairseq/data/mert_dataset.py
-- run_training_mulNodes_wotorchdist_womodelparsize.sh
-
-# Prepare
-
-The MERT training is implemented with [fairseq](https://github.com/pytorch/fairseq). You need to clone the fairseq repo inside our repo at ./src/fairseq and MERT implementation codes as a fairseq example projcet.
-
-You can do that by following the steps:
-```
-mkdir -c ./src/fairseq
-cd ./src
-git clone https://github.com/pytorch/fairseq
-```
-
-
-# Docker
-```
-mirrors.tencent.com/cloudezhou/mert:v3
-```
-
-# Start
-
-### 1-node training
-
-```
-bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node
-```
-
-
-### 1-node training (BEST-RQ)
-
-```
-bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_95M_bestrq
-```
-
-### 4-node training
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 2 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh 3 dummy MERT_RVQ-VAE_CQT_330M_multinodes
-```
-
-
-### 4-node training (BEST-RQ)
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MERT_RVQ-VAE_CQT_95M_bestrq_multinodes BEST_RQ $CHIEF_IP
-```
-
-### 4-node training (MusicFM)
-```
-bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MusicFM_95M_multinodes MUSICFM $CHIEF_IP
-```
-
-### 4-node training (EAT)
-```
-bash run_training_eat.sh $INDEX dummy EAT_pretraining_music_multinodes EAT $CHIEF_IP
-```
-
-You could set the parameters in [mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml](mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml)
-
-Our latest checkpoints is loaded at [data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt](data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt)
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml CHANGED
@@ -18,7 +18,7 @@ checkpoint:

 task:
   _name: mae_image_pretraining
-  data: /hpc_stor03/sjtu_home/wenxi.chen/mydata/audio/unbalanced_train
+  data: unbalanced_train
   rebuild_batches: true
   key: source
   precompute_mask_config: {}
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py CHANGED
@@ -274,11 +274,6 @@ class MERTDataset(FairseqDataset):
         dataset_len:int = 128*3000,
         clip_secs = 5,
     ):
-        # self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
-        #     manifest_path, max_keep_sample_size, min_keep_sample_size
-        # )
-
-        # manifest_path = '/apdcephfs_cq2/share_1297902/speech_user/erichtchen/shixisheng/zhouyz/MERT/music_data/all_v4/train.json'
         self.sample_rate = sample_rate
         self.shuffle = shuffle
         self.random_crop = random_crop
@@ -308,15 +303,8 @@
             self.label_list = [load_label(p, inds, tot) for p in label_paths]
         else:
             self.label_paths = label_paths
-            # self.label_offsets_list = [
-            #     load_label_offset(p, inds, tot) for p in label_paths
-            # ]
+
         assert label_processors is None or len(label_processors) == self.num_labels
-        # logger.info('skip verify labels and audio lengths')
-        # for label_path, label_rate in zip(label_paths, self.label_rates):
-        #     verify_label_lengths(
-        #         self.sizes, sample_rate, label_path, label_rate, inds, tot
-        #     )

         self.max_sample_size = (
             max_sample_size if max_sample_size is not None else sys.maxsize
@@ -330,9 +318,6 @@

         self.augmentation_effects = augmentation_effects
         self.augmentation_probs = augmentation_probs
-        # if len(self.augmentation_effects) > 0:
-        #     self.augmentor_init()
-        #     self.apply_augmentation = self.augmentation_factry(sample_rate)

         self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
         self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
@@ -397,10 +382,7 @@

         return augmented_audio
     def get_audio_by_slice(self,index):
-
-        # wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index])
         wav_path = self.datas[index]['path']
-        # print(wav_path)
         audio_info = torchaudio.info(wav_path)
         origin_sample_rate = audio_info.sample_rate
         origin_duration = audio_info.num_frames / origin_sample_rate
@@ -408,32 +390,14 @@
         wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
         wav = wav.float()

-        # _path, slice_ptr = parse_path(wav_path) # this probably needs changing too
-        # original way
-        # if len(slice_ptr) == 0:
-        #     wav, cur_sample_rate = sf.read(_path)
-        # else:
-        #     assert _path.endswith(".zip")
-        #     data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
-        #     f = io.BytesIO(data)
-        #     wav, cur_sample_rate = sf.read(f)
-        #     wav = torch.from_numpy(wav).float()
-        # print(wav.shape)
         wav = wav.permute(1,0)
         wav = self.postprocess(wav, self.sample_rate) # downmix to mono, check the sample rate, normalize
-        # print(wav.shape)
-
-        # wav = wav.squeeze(0)
         return wav
+
     def get_audio(self, index):
         import soundfile as sf
-
-        # wav_path = os.path.join(self.audio_root, self.audio_names[index])
-        wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index])
-        # print(wav_path)
-        # self.reader()
-        _path, slice_ptr = parse_path(wav_path) # this probably needs changing too
-        # original way
+        wav_path = self.audio_names[index]
+        _path, slice_ptr = parse_path(wav_path)
         if len(slice_ptr) == 0:
             wav, cur_sample_rate = sf.read(_path)
         else:
@@ -448,8 +412,6 @@
         return wav

     def get_label(self, index, label_idx):
-        # label_idx selects the label_idx-th label dictionary (8 by default)
-
         if self.store_labels and (not self.npmemmap):
             label = self.label_list[label_idx][index]
         elif self.store_labels and self.npmemmap:
@@ -570,11 +532,6 @@
             cqt_labels = self.encoder_cqt_model(collated_audios.float(), forward_type='compute_cqt')

         for i, _ in enumerate(audios):
-            # compute cqt labels in advance
-            # cqt_labels
-
-            # yizhilll: apply audio augmentation effects here
-            # the audio should be as the type torch.Tensor, in the shape [1, length] TODO?
             if len(self.augmentation_effects) > 0:
                 with torch.no_grad():
                     for effect, prob in zip(self.augmentation_effects, self.augmentation_probs):
@@ -597,12 +554,12 @@

     def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
         assert label_rate > 0
-        s2f = label_rate / self.sample_rate # @yizhilll: 0.00625 for 100Hz and 16k sr
-        frm_starts = [int(round(s * s2f)) for s in audio_starts] # @yizhilll: should be all 0 if the audios are not croped
-        frm_size = int(round(audio_size * s2f)) # @yizhilll: this is the expected total number of given pseudo labels
+        s2f = label_rate / self.sample_rate
+        frm_starts = [int(round(s * s2f)) for s in audio_starts]
+        frm_size = int(round(audio_size * s2f))
         if not self.pad_audio:
-            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] # @yizhilll: what does this mean?
-            frm_size = min(frm_size, *rem_size) # @yizhilll: anyway, this should keep 3000 for 30s audio
+            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+            frm_size = min(frm_size, *rem_size)
         targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
         logger.debug(f"audio_starts={audio_starts}")
         logger.debug(f"frame_starts={frm_starts}")
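In collater_frm_label, s2f converts sample offsets into pseudo-label frame indices. A hedged numeric sketch, using the rates mentioned in the removed inline comments (100 Hz labels over 16 kHz audio) rather than any shipped config:

```python
# Illustrative values only, echoing the deleted comments; not config defaults.
sample_rate = 16000
label_rate = 100
s2f = label_rate / sample_rate                 # 0.00625 label frames per audio sample

audio_starts = [0, 32000]                      # crop offsets in samples (0 s and 2 s)
audio_size = 480000                            # 30 s of audio in samples

frm_starts = [int(round(s * s2f)) for s in audio_starts]   # [0, 200]
frm_size = int(round(audio_size * s2f))                    # 3000 label frames for 30 s
print(frm_starts, frm_size)
```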
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py CHANGED
@@ -247,15 +247,3 @@ class ChromaSpectrogram(torch.nn.Module):
         chroma_spectrogram[chroma_spectrogram < 0] = 0.0
         chroma_spectrogram = torch.nn.functional.normalize(chroma_spectrogram, p=2, dim=-2)
         return chroma_spectrogram
-
-if __name__ == '__main__':
-    import numpy as np
-    import librosa
-    audio_path = 'speech_data/pretrain/music_42/226849998.flac'
-    sr = 24000
-    freq = 75
-    hop = int(sr // freq)
-    y, _sr = librosa.load(audio_path, duration=5, sr=sr)
-
-    chroma_extractor = ChromaSpectrogram(sample_rate=sr, hop_length=hop, n_fft=2048, use_cqt=True)
-    chroma_tr = chroma_extractor(torch.from_numpy(y)).numpy()
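The removed `__main__` block doubled as a usage example for ChromaSpectrogram; an equivalent hedged sketch (the audio path, sample rate, and frame rate are placeholders, and librosa is only used to load the file):

```python
import torch
import librosa

# Assumes ChromaSpectrogram is importable from this module
# (mert_fairseq/models/mert/chroma_torch.py).
from mert_fairseq.models.mert.chroma_torch import ChromaSpectrogram

audio_path = "example.flac"     # placeholder path
sr = 24000                      # analysis sample rate
freq = 75                       # chroma frames per second
hop = int(sr // freq)           # hop length in samples

y, _sr = librosa.load(audio_path, duration=5, sr=sr)
chroma_extractor = ChromaSpectrogram(sample_rate=sr, hop_length=hop, n_fft=2048, use_cqt=True)
chroma = chroma_extractor(torch.from_numpy(y)).numpy()   # chroma bins over time
```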
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py CHANGED
@@ -1293,10 +1293,8 @@ class MERTModel(BaseFairseqModel):
         feat_tsz = features.size(2)
         targ_tsz = min([t.size(1) for t in target_list])
         if self.feat2tar_ratio * feat_tsz > targ_tsz:
-            # @yizhilll: if feature * 2 > 3000, then crop the features
             feat_tsz = int(targ_tsz / self.feat2tar_ratio)
             features = features[..., :feat_tsz]
-        # @yizhilll: select only the first pseoudo label if there are multiple labels
         target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
         target_list = [t[:, target_inds.long()] for t in target_list]
         return features, target_list
@@ -1507,8 +1505,6 @@

         if not self.skip_masked:
             masked_indices = torch.logical_and(~padding_mask, mask_indices)
-
-            # @yizhilll: TODO merge the codes heredui
             if self.random_codebook <= 0:
                 proj_x_m = self.final_proj(x[masked_indices]) # project the features into a lower-dimensional space
                 if self.untie_final_proj:
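The cropping logic in the first hunk keeps features and pseudo labels aligned when their frame rates differ. A standalone hedged sketch of the same arithmetic (the shapes and ratio are illustrative assumptions, not model defaults):

```python
import torch

feat2tar_ratio = 2.0                              # assumed: labels run at 2x the feature rate
features = torch.randn(4, 768, 1501)              # (batch, dim, feature frames)
target_list = [torch.randint(0, 500, (4, 3000))]  # (batch, label frames)

feat_tsz = features.size(2)
targ_tsz = min(t.size(1) for t in target_list)
if feat2tar_ratio * feat_tsz > targ_tsz:
    feat_tsz = int(targ_tsz / feat2tar_ratio)     # crop features so they fit the labels
    features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
print(features.shape, target_list[0].shape)       # both end up with 1500 frames
```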
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py CHANGED
@@ -20,8 +20,7 @@ from fairseq.tasks import register_task
 from fairseq.tasks.fairseq_task import FairseqTask
 from omegaconf import MISSING

-# from ..data.mert_dataset import MERTDataset
-from ..data.mert_dataset import MERTDataset # this feels quite problematic, needs a different approach
+from ..data.mert_dataset import MERTDataset
 from ..data.ark_dataset import ArkDataset

 logger = logging.getLogger(__name__)
@@ -32,8 +31,6 @@ class LabelEncoder(object):
         self.dictionary = dictionary

     def __call__(self, label: str) -> List[str]:
-        # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \
-        # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
         return self.dictionary.encode_line(
             label,
             append_eos=False,
@@ -45,17 +42,6 @@ class PaddedNumpyLabelEncoder(object):
         pass

     def __call__(self, label):
-        # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \
-        # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
-        # return self.dictionary.encode_line(
-        #     label,
-        #     append_eos=False,
-        #     add_if_not_exist=False,
-        # )
-        # if isisntance(label, np.memmap):
-
-        # assert isisntance(label, np.memmap)
-        # t = torch.IntTensor(np.asarray(label).copy())
         t = torch.IntTensor(np.asarray(label))
         t = t[t>=0] # remove padded -1 values at the end
         return t
@@ -262,9 +248,7 @@ class MERTPretrainingTask(FairseqTask):
         else:
             self.state.add_factory("dictionaries", self.load_dictionaries)

-        self.blank_symbol = "<s>"
-
-        # @yizhilll: use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
+        self.blank_symbol = "<s>"
         self.augmentation_effects = eval(self.cfg.augmentation_effects)
         self.augmentation_probs = eval(self.cfg.augmentation_probs)
         if len(self.augmentation_effects) > 0:
@@ -321,14 +305,6 @@
             return self.cfg.data
         return self.cfg.label_dir

-    # def has_sharded_data(self, split):
-    #     """overwrite this function for let the trainier do dataset reload for changing the the dynamic croppings"""
-    #     logger.info(f"check whether to re-load dataset for epoch {epoch} by overwritting task.has_sharded_data()")
-    #     # find the threshold that holds epoch \in [threshold, next_threshold)
-    #     is_reload_dataset = epoch in self.dynamic_crops_epoches
-
-    #     return os.pathsep in getattr(self.cfg, "data", "") or is_reload_dataset
-    # def is_force_load_dataset(self, epoch):
     def is_force_load_dataset(self, epoch, training_restore=False):
         # find the threshold that holds epoch \in [threshold, next_threshold)
         return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
@@ -340,15 +316,6 @@

     def set_dynamic_crop_max_sample(self, epoch):
         """ force to set the max_sample_size config for the dynamic cropping function"""
-        # pass
-        # @yizhilll: the parameter "epoch" is passed into this funciton in trainer.py#688,
-        # containing in "**kwargs"
-        # if 'train' in split:
-        #     epoch = kwargs['epoch']
-
-        # find the threshold that holds epoch \in [threshold, next_threshold)
-        # is_reload_dataset = epoch in self.dynamic_crops_epoches # test again
-        # if is_reload_dataset:
         if epoch in self.dynamic_crops_epoches:
             for idx in range(len(self.dynamic_crops_epoches)):
                 if (idx == len(self.dynamic_crops_epoches)-1) or \
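PaddedNumpyLabelEncoder above strips the -1 values used to right-pad label rows. A tiny hedged sketch of that behavior (the label values are made up):

```python
import numpy as np
import torch

label = np.array([17, 3, 250, 8, -1, -1, -1])   # illustrative padded label row
t = torch.IntTensor(np.asarray(label))
t = t[t >= 0]                                    # padding removed: tensor of [17, 3, 250, 8]
print(t)
```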