# DPTNet_quant_sep.py
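"""Quantized DPTNet speech separation.

Downloads an int8-quantized DPTNet checkpoint from the Hugging Face Hub and
separates a two-speaker mixture wav into its two source signals.
"""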

import warnings
warnings.filterwarnings("ignore", message="Failed to initialize NumPy: _ARRAY_API not found")

import os
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download

# Dynamically import the DPTNet implementation from the bundled asteroid_test module
try:
    from . import asteroid_test
except ImportError as e:
    raise ImportError("無法載入 asteroid_test 模組,請確認該模組與訓練時相同") from e

# Note: set_audio_backend is deprecated in recent torchaudio releases;
# this call targets the older torchaudio API used at training time.
torchaudio.set_audio_backend("sox_io")


def get_conf():
    """取得模型參數設定"""
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }

    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet


def load_dpt_model():
    print('Loading separation model...')

    speech_sep_token = os.getenv("SpeechSeparation")
    if not speech_sep_token:
        raise EnvironmentError("環境變數 SpeechSeparation 未設定!")

    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=speech_sep_token
    )

    conf_filterbank, conf_masknet = get_conf()

    try:
        model_class = getattr(asteroid_test, "DPTNet")
        model = model_class(**conf_filterbank, **conf_masknet)
    except Exception as e:
        raise RuntimeError("模型結構錯誤:請確認 asteroid_test.py 是否與訓練時相同") from e

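    # The checkpoint was saved from a dynamically quantized model (note the
    # "quan-int8" suffix in its filename), so the same int8 quantization must
    # be applied before the weights are loaded.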
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.LSTM, torch.nn.Linear},
        dtype=torch.qint8
    )

    state_dict = torch.load(model_path, map_location="cpu")
    own_state = model.state_dict()

    # Keep only key-value pairs that are tensors with matching shapes
    filtered_state_dict = {}
    for k, v in state_dict.items():
        if k in own_state:
            if isinstance(v, torch.Tensor) and isinstance(own_state[k], torch.Tensor):
                if v.shape == own_state[k].shape:
                    filtered_state_dict[k] = v
                else:
                    print(f"Skip '{k}': shape mismatch")
            else:
                print(f"Skip '{k}': not a tensor")

    # strict=False: keys dropped by the filter above are reported instead of raising
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

    if missing_keys:
        print("⚠️ Missing keys:", missing_keys)
    if unexpected_keys:
        print("ℹ️ Unexpected keys:", unexpected_keys)

    model.eval()
    return model


def dpt_sep_process(wav_path, model=None, outfilename=None):
    """進行語音分離處理"""
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)
    x = x.cpu()

    with torch.no_grad():
        est_sources = model(x)  # shape: (1, 2, T)

    est_sources = est_sources.squeeze(0)  # shape: (2, T)
    sep_1, sep_2 = est_sources  # split into two (T,) tensors

    # Normalize: rescale each separated source to the mixture's peak amplitude
    max_abs = x[0].abs().max().item()
    sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
    sep_2 = sep_2 * max_abs / sep_2.abs().max().item()

    # Add a channel dimension to get shape (1, T)
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)

    # Save the results
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)


if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")