HoneyTian commited on
Commit
2ebb5f8
·
1 Parent(s): 88b2fbf
examples/conv_tasnet/run.sh CHANGED
@@ -71,9 +71,8 @@ file_dir="${work_dir}/${file_folder_name}"
71
  final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
72
  evaluation_audio_dir="${file_dir}/evaluation_audio"
73
 
74
- dataset="${file_dir}/dataset.xlsx"
75
- train_dataset="${file_dir}/train.xlsx"
76
- valid_dataset="${file_dir}/valid.xlsx"
77
 
78
  $verbose && echo "system_version: ${system_version}"
79
  $verbose && echo "file_folder_name: ${file_folder_name}"
 
71
  final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
72
  evaluation_audio_dir="${file_dir}/evaluation_audio"
73
 
74
+ train_dataset="${file_dir}/train.jsonl"
75
+ valid_dataset="${file_dir}/valid.jsonl"
 
76
 
77
  $verbose && echo "system_version: ${system_version}"
78
  $verbose && echo "file_folder_name: ${file_folder_name}"
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -25,7 +25,7 @@ from torch.nn import functional as F
25
  from torch.utils.data.dataloader import DataLoader
26
  from tqdm import tqdm
27
 
28
- from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
29
  from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
30
  from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
31
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
@@ -125,37 +125,37 @@ def main():
125
  logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
 
127
  # datasets
128
- train_dataset = DenoiseExcelDataset(
129
- excel_file=args.train_dataset,
130
  expected_sample_rate=8000,
131
  max_wave_value=32768.0,
132
  )
133
- valid_dataset = DenoiseExcelDataset(
134
- excel_file=args.valid_dataset,
135
  expected_sample_rate=8000,
136
  max_wave_value=32768.0,
137
  )
138
  train_data_loader = DataLoader(
139
  dataset=train_dataset,
140
  batch_size=config.batch_size,
141
- shuffle=True,
142
  sampler=None,
143
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
144
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
145
  collate_fn=collate_fn,
146
  pin_memory=False,
147
- prefetch_factor=16,
148
  )
149
  valid_data_loader = DataLoader(
150
  dataset=valid_dataset,
151
  batch_size=config.batch_size,
152
- shuffle=True,
153
  sampler=None,
154
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
155
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
156
  collate_fn=collate_fn,
157
  pin_memory=False,
158
- prefetch_factor=16,
159
  )
160
 
161
  # models
 
25
  from torch.utils.data.dataloader import DataLoader
26
  from tqdm import tqdm
27
 
28
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
  from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
30
  from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
31
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 
125
  logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
 
127
  # datasets
128
+ train_dataset = DenoiseJsonlDataset(
129
+ jsonl_file=args.train_dataset,
130
  expected_sample_rate=8000,
131
  max_wave_value=32768.0,
132
  )
133
+ valid_dataset = DenoiseJsonlDataset(
134
+ jsonl_file=args.valid_dataset,
135
  expected_sample_rate=8000,
136
  max_wave_value=32768.0,
137
  )
138
  train_data_loader = DataLoader(
139
  dataset=train_dataset,
140
  batch_size=config.batch_size,
141
+ # shuffle=True,
142
  sampler=None,
143
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
144
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
145
  collate_fn=collate_fn,
146
  pin_memory=False,
147
+ prefetch_factor=2,
148
  )
149
  valid_data_loader = DataLoader(
150
  dataset=valid_dataset,
151
  batch_size=config.batch_size,
152
+ # shuffle=True,
153
  sampler=None,
154
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
155
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
156
  collate_fn=collate_fn,
157
  pin_memory=False,
158
+ prefetch_factor=2,
159
  )
160
 
161
  # models
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py CHANGED
@@ -2,6 +2,8 @@
2
  # -*- coding: utf-8 -*-
3
  import json
4
  import os
 
 
5
 
6
  import librosa
7
  import numpy as np
@@ -9,28 +11,54 @@ import pandas as pd
9
  from scipy.io import wavfile
10
  import torch
11
  import torchaudio
12
- from torch.utils.data import Dataset
13
  from tqdm import tqdm
14
 
15
 
16
- class DenoiseJsonlDataset(Dataset):
17
  def __init__(self,
18
  jsonl_file: str,
19
  expected_sample_rate: int,
20
  resample: bool = False,
21
  max_wave_value: float = 1.0,
 
22
  ):
23
  self.jsonl_file = jsonl_file
24
  self.expected_sample_rate = expected_sample_rate
25
  self.resample = resample
26
  self.max_wave_value = max_wave_value
27
 
28
- self.samples = self.load_samples(jsonl_file)
 
29
 
30
- @staticmethod
31
- def load_samples(filename: str):
32
- samples = list()
33
- with open(filename, "r", encoding="utf-8") as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  for row in f:
35
  row = json.loads(row)
36
  noise_filename = row["noise_filename"]
@@ -58,11 +86,10 @@ class DenoiseJsonlDataset(Dataset):
58
 
59
  "snr_db": snr_db,
60
  }
61
- samples.append(row)
62
- return samples
63
 
64
- def __getitem__(self, index):
65
- sample = self.samples[index]
66
  noise_filename = sample["noise_filename"]
67
  noise_offset = sample["noise_offset"]
68
  noise_duration = sample["noise_duration"]
@@ -92,9 +119,6 @@ class DenoiseJsonlDataset(Dataset):
92
  }
93
  return result
94
 
95
- def __len__(self):
96
- return len(self.samples)
97
-
98
  def filename_to_waveform(self, filename: str, offset: float, duration: float):
99
  try:
100
  waveform, sample_rate = librosa.load(
@@ -129,5 +153,5 @@ class DenoiseJsonlDataset(Dataset):
129
  return noisy_signal, noise_adjusted
130
 
131
 
132
- if __name__ == '__main__':
133
  pass
 
2
  # -*- coding: utf-8 -*-
3
  import json
4
  import os
5
+ import random
6
+ from typing import List
7
 
8
  import librosa
9
  import numpy as np
 
11
  from scipy.io import wavfile
12
  import torch
13
  import torchaudio
14
+ from torch.utils.data import Dataset, IterableDataset
15
  from tqdm import tqdm
16
 
17
 
18
+ class DenoiseJsonlDataset(IterableDataset):
19
  def __init__(self,
20
  jsonl_file: str,
21
  expected_sample_rate: int,
22
  resample: bool = False,
23
  max_wave_value: float = 1.0,
24
+ buffer_size: int = 1000,
25
  ):
26
  self.jsonl_file = jsonl_file
27
  self.expected_sample_rate = expected_sample_rate
28
  self.resample = resample
29
  self.max_wave_value = max_wave_value
30
 
31
+ self.buffer_size = buffer_size
32
+ self.buffer_samples: List[dict] = list()
33
 
34
+ def __iter__(self):
35
+ iterable_source = self.iterable_source()
36
+
37
+ # 初始填充缓冲区
38
+ try:
39
+ for _ in range(self.buffer_size):
40
+ self.buffer_samples.append(next(iterable_source))
41
+ except StopIteration:
42
+ pass
43
+
44
+ # 动态替换逻辑
45
+ while True:
46
+ try:
47
+ item = next(iterable_source)
48
+ # 随机替换缓冲区元素
49
+ replace_idx = random.randint(0, len(self.buffer_samples) - 1)
50
+ yield self.buffer_samples[replace_idx]
51
+ self.buffer_samples[replace_idx] = item
52
+ except StopIteration:
53
+ break
54
+
55
+ # 清空剩余元素
56
+ random.shuffle(self.buffer_samples)
57
+ for sample in self.buffer_samples:
58
+ yield sample
59
+
60
+ def iterable_source(self):
61
+ with open(self.jsonl_file, "r", encoding="utf-8") as f:
62
  for row in f:
63
  row = json.loads(row)
64
  noise_filename = row["noise_filename"]
 
86
 
87
  "snr_db": snr_db,
88
  }
89
+ sample = self.convert_sample(row)
90
+ yield sample
91
 
92
+ def convert_sample(self, sample: dict):
 
93
  noise_filename = sample["noise_filename"]
94
  noise_offset = sample["noise_offset"]
95
  noise_duration = sample["noise_duration"]
 
119
  }
120
  return result
121
 
 
 
 
122
  def filename_to_waveform(self, filename: str, offset: float, duration: float):
123
  try:
124
  waveform, sample_rate = librosa.load(
 
153
  return noisy_signal, noise_adjusted
154
 
155
 
156
+ if __name__ == "__main__":
157
  pass