HoneyTian commited on
Commit
af4c931
·
1 Parent(s): 302f91c
.gitignore CHANGED
@@ -13,6 +13,7 @@
13
  /docs/
14
  /dotenv/
15
  /hub_datasets/
 
16
  /thirdparty/
17
  /trained_models/
18
  /temp/
 
13
  /docs/
14
  /dotenv/
15
  /hub_datasets/
16
+ /script/
17
  /thirdparty/
18
  /trained_models/
19
  /temp/
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -313,19 +313,19 @@ def main():
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
- if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
- raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
- if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
- raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
- speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
- snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
- if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
- raise AssertionError("nan or inf in snr_loss")
329
 
330
  # loss = speech_loss + irm_loss + snr_loss
331
  loss = irm_loss
@@ -361,16 +361,16 @@ def main():
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
- if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
- raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
- if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
- raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
- speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
- snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
  # loss = speech_loss + irm_loss + snr_loss
376
  loss = irm_loss
 
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
+ # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
+ # raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
+ # if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
+ # raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
+ # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
+ # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
+ # raise AssertionError("nan or inf in snr_loss")
329
 
330
  # loss = speech_loss + irm_loss + snr_loss
331
  loss = irm_loss
 
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
+ # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
+ # raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
+ # if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
+ # raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
+ # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
  # loss = speech_loss + irm_loss + snr_loss
376
  loss = irm_loss
script/install_nvidia_driver.sh ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #GPU驱动安装需要先将原有的显示关闭, 重启机器, 再进行安装.
3
+ #参考链接:
4
+ #https://blog.csdn.net/kingschan/article/details/19033595
5
+ #https://blog.csdn.net/HaixWang/article/details/90408538
6
+ #
7
+ #>>> yum install -y pciutils
8
+ #查看 linux 机器上是否有 GPU
9
+ #lspci |grep -i nvidia
10
+ #
11
+ #>>> lspci |grep -i nvidia
12
+ #00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
13
+ #
14
+ #
15
+ #NVIDIA 驱动程序下载
16
+ #先在 pytorch 上查看应该用什么 cuda 版本, 再安装对应的 cuda-toolkit cuda.
17
+ #再根据 gpu 版本下载安装对应的 nvidia 驱动
18
+ #
19
+ ## pytorch 版本
20
+ #https://pytorch.org/get-started/locally/
21
+ #
22
+ ## CUDA 下载 (好像不需要这个)
23
+ #https://developer.nvidia.com/cuda-toolkit-archive
24
+ #
25
+ ## nvidia 驱动
26
+ #https://www.nvidia.cn/Download/index.aspx?lang=cn
27
+ #http://www.nvidia.com/Download/index.aspx
28
+ #
29
+ #在下方的下拉列表中进行选择,针对您的 NVIDIA 产品确定合适的驱动。
30
+ #产品类型:
31
+ #Data Center / Tesla
32
+ #产品系列:
33
+ #T-Series
34
+ #产品家族:
35
+ #Tesla T4
36
+ #操作系统:
37
+ #Linux 64-bit
38
+ #CUDA Toolkit:
39
+ #10.2
40
+ #语言:
41
+ #Chinese (Simpleified)
42
+ #
43
+ #
44
+ #>>> mkdir -p /data/tianxing
45
+ #>>> cd /data/tianxing
46
+ #>>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
47
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
48
+ #
49
+ ## 异常:
50
+ #ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
51
+ #Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
52
+ #[OK]
53
+ #
54
+ #For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
55
+ #[NO]
56
+ #
57
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
58
+ #page at www.nvidia.com.
59
+ #[OK]
60
+ #
61
+ ## 参考链接:
62
+ #https://blog.csdn.net/kingschan/article/details/19033595
63
+ #
64
+ ## 禁用原有的显卡驱动 nouveau
65
+ #>>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
66
+ #>>> sudo dracut --force
67
+ ## 重启
68
+ #>>> reboot
69
+ #
70
+ #>>> init 3
71
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
72
+ #
73
+ ## 异常
74
+ #ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
75
+ #[OK]
76
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
77
+ #page at www.nvidia.com.
78
+ #[OK]
79
+ #
80
+ ## 参考链接
81
+ ## https://blog.csdn.net/HaixWang/article/details/90408538
82
+ #
83
+ #>>> uname -r
84
+ #3.10.0-1160.49.1.el7.x86_64
85
+ #>>> yum install kernel-devel kernel-headers -y
86
+ #>>> yum info kernel-devel kernel-headers
87
+ #>>> yum install -y "kernel-devel-uname-r == $(uname -r)"
88
+ #>>> yum -y distro-sync
89
+ #
90
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
91
+ #
92
+ ## 安装成功
93
+ #WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
94
+ #module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
95
+ #[OK]
96
+ #Install NVIDIA's 32-bit compatibility libraries?
97
+ #[YES]
98
+ #Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
99
+ #[OK]
100
+ #
101
+ #
102
+ ## 查看 GPU 使用情况; watch -n 1 -d nvidia-smi 每1秒刷新一次.
103
+ #>>> nvidia-smi
104
+ #Thu Mar 9 12:00:37 2023
105
+ #+-----------------------------------------------------------------------------+
106
+ #| NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
107
+ #|-------------------------------+----------------------+----------------------+
108
+ #| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
109
+ #| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
110
+ #|===============================+======================+======================|
111
+ #| 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
112
+ #| N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
113
+ #+-------------------------------+----------------------+----------------------+
114
+ #
115
+ #+-----------------------------------------------------------------------------+
116
+ #| Processes: GPU Memory |
117
+ #| GPU PID Type Process name Usage |
118
+ #|=============================================================================|
119
+ #| No running processes found |
120
+ #+-----------------------------------------------------------------------------+
121
+ #
122
+ #
123
+
124
+ # params
125
+ stage=1
126
+ nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
127
+
128
+ # parse options
129
+ while true; do
130
+ [ -z "${1:-}" ] && break; # break if there are no arguments
131
+ case "$1" in
132
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
133
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
134
+ old_value="(eval echo \\$$name)";
135
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
136
+ was_bool=true;
137
+ else
138
+ was_bool=false;
139
+ fi
140
+
141
+ # Set the variable to the right value-- the escaped quotes make it work if
142
+ # the option had spaces, like --cmd "queue.pl -sync y"
143
+ eval "${name}=\"$2\"";
144
+
145
+ # Check that Boolean-valued arguments are really Boolean.
146
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
147
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
148
+ exit 1;
149
+ fi
150
+ shift 2;
151
+ ;;
152
+
153
+ *) break;
154
+ esac
155
+ done
156
+
157
+ echo "stage: ${stage}";
158
+
159
+ yum -y install wget
160
+ yum -y install sudo
161
+
162
+ if [ ${stage} -eq 0 ]; then
163
+ mkdir -p /data/dep
164
+ cd /data/dep || echo 1;
165
+ wget -P /data/dep ${nvidia_driver_filename}
166
+
167
+ echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
168
+ sudo dracut --force
169
+ # 重启
170
+ reboot
171
+ elif [ ${stage} -eq 1 ]; then
172
+ init 3
173
+
174
+ yum install -y kernel-devel kernel-headers
175
+ yum info kernel-devel kernel-headers
176
+ yum install -y "kernel-devel-uname-r == $(uname -r)"
177
+ yum -y distro-sync
178
+
179
+ cd /data/dep || echo 1;
180
+
181
+ # 安装时, 需要回车三下.
182
+ sh NVIDIA-Linux-x86_64-440.118.02.run
183
+ nvidia-smi
184
+ fi
script/install_python.sh ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # 参数:
4
+ python_version="3.6.5";
5
+ system_version="centos";
6
+
7
+
8
+ # parse options
9
+ while true; do
10
+ [ -z "${1:-}" ] && break; # break if there are no arguments
11
+ case "$1" in
12
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
13
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
14
+ old_value="(eval echo \\$$name)";
15
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
16
+ was_bool=true;
17
+ else
18
+ was_bool=false;
19
+ fi
20
+
21
+ # Set the variable to the right value-- the escaped quotes make it work if
22
+ # the option had spaces, like --cmd "queue.pl -sync y"
23
+ eval "${name}=\"$2\"";
24
+
25
+ # Check that Boolean-valued arguments are really Boolean.
26
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
27
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
28
+ exit 1;
29
+ fi
30
+ shift 2;
31
+ ;;
32
+
33
+ *) break;
34
+ esac
35
+ done
36
+
37
+ echo "python_version: ${python_version}";
38
+ echo "system_version: ${system_version}";
39
+
40
+
41
+ if [ ${system_version} = "centos" ]; then
42
+ # 安装 python 开发编译环境
43
+ yum -y groupinstall "Development tools"
44
+ yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
45
+ yum install libffi-devel -y
46
+ yum install -y wget
47
+ yum install -y make
48
+
49
+ mkdir -p /data/dep
50
+ cd /data/dep || exit 1;
51
+ if [ ! -e Python-${python_version}.tgz ]; then
52
+ wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
53
+ fi
54
+
55
+ cd /data/dep || exit 1;
56
+ if [ ! -d Python-${python_version} ]; then
57
+ tar -zxvf Python-${python_version}.tgz
58
+ cd /data/dep/Python-${python_version} || exit 1;
59
+ fi
60
+
61
+ mkdir /usr/local/python-${python_version}
62
+ ./configure --prefix=/usr/local/python-${python_version}
63
+ make && make install
64
+
65
+ /usr/local/python-${python_version}/bin/python3 -V
66
+ /usr/local/python-${python_version}/bin/pip3 -V
67
+
68
+ rm -rf /usr/local/bin/python3
69
+ rm -rf /usr/local/bin/pip3
70
+ ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
71
+ ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
72
+
73
+ python3 -V
74
+ pip3 -V
75
+
76
+ elif [ ${system_version} = "ubuntu" ]; then
77
+ # 安装 python 开发编译环境
78
+ # https://zhuanlan.zhihu.com/p/506491209
79
+
80
+ # 刷新软件包目录
81
+ sudo apt update
82
+ # 列出当前可用的更新
83
+ sudo apt list --upgradable
84
+ # 如上一步提示有可以更新的项目,则执行更新
85
+ sudo apt -y upgrade
86
+ # 安装 GCC 编译器
87
+ sudo apt install gcc
88
+ # 检查安装是否成功
89
+ gcc -v
90
+
91
+ # 安装依赖
92
+ sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev
93
+
94
+ mkdir -p /data/dep
95
+ cd /data/dep || exit 1;
96
+ if [ ! -e Python-${python_version}.tgz ]; then
97
+ # sudo wget -P /data/dep https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tgz
98
+ sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
99
+ fi
100
+
101
+ cd /data/dep || exit 1;
102
+ if [ ! -d Python-${python_version} ]; then
103
+ # tar -zxvf Python-3.6.5.tgz
104
+ tar -zxvf Python-${python_version}.tgz
105
+ # cd /data/dep/Python-3.6.5
106
+ cd /data/dep/Python-${python_version} || exit 1;
107
+ fi
108
+
109
+ # mkdir /usr/local/python-3.6.5
110
+ mkdir /usr/local/python-${python_version}
111
+
112
+ # 检查依赖与配置编译
113
+ # sudo ./configure --prefix=/usr/local/python-3.6.5 --enable-optimizations --with-lto --enable-shared
114
+ sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
115
+ cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
116
+ # sudo make -j 4
117
+ sudo make -j "${cpu_count}"
118
+
119
+ /usr/local/python-${python_version}/bin/python3 -V
120
+ /usr/local/python-${python_version}/bin/pip3 -V
121
+
122
+ rm -rf /usr/local/bin/python3
123
+ rm -rf /usr/local/bin/pip3
124
+ ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
125
+ ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
126
+
127
+ python3 -V
128
+ pip3 -V
129
+ fi
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -831,30 +831,30 @@ class SpectrumDfNet(nn.Module):
831
  if torch.any(mask > 1) or torch.any(mask < 0):
832
  raise AssertionError
833
 
834
- spec_m = self.mask.forward(spec, mask)
835
-
836
- # lsnr shape: [batch_size, time_steps, 1]
837
- lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
838
- # lsnr shape: [batch_size, 1, time_steps]
839
-
840
- df_coefs = self.df_decoder.forward(emb, c0)
841
- df_coefs = self.df_out_transform(df_coefs)
842
- # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
843
-
844
- spec_e = self.df_op.forward(spec.clone(), df_coefs)
845
- # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
846
-
847
- spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
848
-
849
- spec_e = torch.squeeze(spec_e, dim=1)
850
- spec_e = spec_e.permute(0, 2, 1, 3)
851
- # spec_e shape: [batch_size, spec_bins, time_steps, 2]
852
 
853
  mask = torch.squeeze(mask, dim=1)
854
  mask = mask.permute(0, 2, 1)
855
  # mask shape: [batch_size, spec_bins, time_steps]
856
 
857
- return spec_e, mask, lsnr
858
 
859
 
860
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):
 
831
  if torch.any(mask > 1) or torch.any(mask < 0):
832
  raise AssertionError
833
 
834
+ # spec_m = self.mask.forward(spec, mask)
835
+ #
836
+ # # lsnr shape: [batch_size, time_steps, 1]
837
+ # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
838
+ # # lsnr shape: [batch_size, 1, time_steps]
839
+ #
840
+ # df_coefs = self.df_decoder.forward(emb, c0)
841
+ # df_coefs = self.df_out_transform(df_coefs)
842
+ # # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
843
+ #
844
+ # spec_e = self.df_op.forward(spec.clone(), df_coefs)
845
+ # # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
846
+ #
847
+ # spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
848
+ #
849
+ # spec_e = torch.squeeze(spec_e, dim=1)
850
+ # spec_e = spec_e.permute(0, 2, 1, 3)
851
+ # # spec_e shape: [batch_size, spec_bins, time_steps, 2]
852
 
853
  mask = torch.squeeze(mask, dim=1)
854
  mask = mask.permute(0, 2, 1)
855
  # mask shape: [batch_size, spec_bins, time_steps]
856
 
857
+ return None, mask, lsnr
858
 
859
 
860
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):