# DPTNet_quant_sep.py
import warnings
warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")
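# Note: the filtered message is the NumPy ABI-mismatch warning that some
# PyTorch builds emit at import time (a module compiled against NumPy 1.x
# running under NumPy 2.x); silencing it just keeps startup logs clean.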
import os
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download
# Dynamically import DPTNet from asteroid_test
try:
    from . import asteroid_test
except ImportError as e:
    raise ImportError("Failed to load the asteroid_test module; make sure it is identical to the one used during training") from e

torchaudio.set_audio_backend("sox_io")
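# Note: torchaudio.set_audio_backend is deprecated in torchaudio 2.x (backend
# selection is dispatched automatically there); this call targets the older
# torchaudio versions this script was written against.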
def get_conf():
    """Return the model hyperparameter configuration (filterbank and masknet)."""
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }
    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet
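
# Note: these hyperparameters must match the ones the checkpoint was trained
# with (n_src=2 means the model separates two speakers); any weight whose
# shape disagrees with the checkpoint is skipped in load_dpt_model below.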

def load_dpt_model():
    print('Load Separation Model...')
    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("Environment variable SpeechSeparation is not set!")
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )
    conf_filterbank, conf_masknet = get_conf()
    try:
        model_class = getattr(asteroid_test, "DPTNet")
        model = model_class(**conf_filterbank, **conf_masknet)
    except Exception as e:
        raise RuntimeError("Model structure error: check that asteroid_test.py is identical to the version used during training") from e
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.LSTM, torch.nn.Linear},
        dtype=torch.qint8
    )
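    # Load the int8 checkpoint downloaded above. quantize_dynamic must run
    # before load_state_dict so the model's parameter names and dtypes line
    # up with the dynamically quantized weights in the checkpoint.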
    state_dict = torch.load(model_path, map_location="cpu")
    own_state = model.state_dict()
    # Keep only the key-value pairs that are torch.Tensors with matching shapes
    filtered_state_dict = {}
    for k, v in state_dict.items():
        if k in own_state:
            if isinstance(v, torch.Tensor) and isinstance(own_state[k], torch.Tensor):
                if v.shape == own_state[k].shape:
                    filtered_state_dict[k] = v
                else:
                    print(f"Skip '{k}': shape mismatch")
            else:
                print(f"Skip '{k}': not a tensor")
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
    if missing_keys:
        print("⚠️ Missing keys:", missing_keys)
    if unexpected_keys:
        print("ℹ️ Unexpected keys:", unexpected_keys)
    model.eval()
    return model

def dpt_sep_process(wav_path, model=None, outfilename=None):
    """Run speech separation on a WAV file and save the separated sources."""
    if model is None:
        model = load_dpt_model()
    x, sr = torchaudio.load(wav_path)
    x = x.cpu()
    with torch.no_grad():
        est_sources = model(x)  # shape: (1, 2, T)
    est_sources = est_sources.squeeze(0)  # shape: (2, T)
    sep_1, sep_2 = est_sources  # split into two (T,) tensors
    # Normalize: rescale each separated source to the mixture's peak amplitude
    max_abs = x[0].abs().max().item()
    sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
    sep_2 = sep_2 * max_abs / sep_2.abs().max().item()
    # Add a channel dimension, giving shape (1, T)
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)
    # Save the results (the '.wav' -> '_sepN.wav' substitution assumes the
    # path actually ends in '.wav')
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)

if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")