Spaces:
Running
Running
import os | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy | |
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset | |
from scipy import signal | |
from tqdm import tqdm | |
from data_reader import Config | |
# Select the Agg backend so that every figure in this module is rendered
# straight to image files via plt.savefig.
matplotlib.use('agg')
def plot_result(epoch, num, figure_dir, preds, X, Y, mode="valid"):
    """Save a 4x2 diagnostic figure for each of the first ``num`` samples.

    For every sample the complex spectrogram is rebuilt as
    ``X[i, :, :, 0] + 1j * X[i, :, :, 1]``. It is inverted with
    ``scipy.signal.istft`` three times: unmodified, multiplied by the target
    mask ``Y[i, :, :, 0]`` and multiplied by the predicted mask
    ``preds[i, :, :, 0]``. The ISTFT parameters (``fs``, ``nperseg``,
    ``nfft``) are read from ``Config()``.

    Args:
        epoch: Integer used in the output file name.
        num: Maximum number of samples to plot; capped at ``len(X)``.
        figure_dir: Directory the PNG files are written to. It is not
            created here.
        preds: Predicted masks, read as ``preds[i, :, :, 0]``.
        X: Spectrograms, channel 0 used as real part and channel 1 as
            imaginary part.
        Y: Target masks, read as ``Y[i, :, :, 0]``.
        mode: Suffix appended to the output file name.

    Returns:
        Always ``0``.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')

    for i in range(min(num, len(X))):
        spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
        magnitude = np.abs(spectrum)
        target_mask = Y[i, :, :, 0]
        predicted_mask = preds[i, :, :, 0]

        t, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
        t, ideal_denoised_signal = scipy.signal.istft(spectrum * target_mask, **istft_kwargs)
        t, denoised_signal = scipy.signal.istft(spectrum * predicted_mask, **istft_kwargs)

        plt.figure(i)
        fig = plt.gcf()
        fig.set_size_inches(fig.get_size_inches() * [1.5, 1.5])

        # subplot index -> (image, colour-scale maximum, title)
        image_panels = {
            1: (magnitude, 2, "Noisy signal"),
            3: (target_mask, 1, "Y"),
            # Raw string: "\h" is not a valid escape sequence in a normal literal.
            4: (predicted_mask, 1, r"$\hat{Y}$"),
            5: (magnitude * target_mask, 2, "Ideal denoised signal"),
            6: (magnitude * predicted_mask, 2, "Denoised signal"),
        }
        # subplot index -> (waveform, legend label)
        trace_panels = {
            2: (noisy_signal, 'Noisy signal'),
            7: (ideal_denoised_signal, 'Ideal denoised signal'),
            8: (denoised_signal, 'Denoised signal'),
        }

        signal_ylim = None
        for position in range(1, 9):
            plt.subplot(4, 2, position)
            if position in image_panels:
                image, vmax, title = image_panels[position]
                plt.pcolormesh(image, vmin=0, vmax=vmax)
                plt.title(title)
                plt.gca().set_xticklabels([])
                continue

            trace, label = trace_panels[position]
            plt.plot(t, trace, label=label, linewidth=0.1)
            if signal_ylim is None:
                # The noisy trace (first waveform panel) fixes the y-range
                # that the two denoised traces reuse.
                signal_ylim = plt.gca().get_ylim()
                plt.gca().set_xticklabels([])
            else:
                plt.ylim(signal_ylim)
                plt.xlabel("Time (s)")
            plt.legend(loc='lower left')
            plt.margins(x=0)

        plt.tight_layout()
        plt.gcf().align_labels()
        plt.savefig(
            os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)),
            bbox_inches='tight',
        )
        plt.close(i)
    return 0
def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"):
    """Save the 4x2 diagnostic figure for the single sample ``i``.

    The complex spectrogram ``X[i, :, :, 0] + 1j * X[i, :, :, 1]`` is inverted
    with ``scipy.signal.istft`` three times: unmodified, multiplied by the
    target mask ``Y[i, :, :, 0]`` and multiplied by the predicted mask
    ``preds[i, :, :, 0]``. ISTFT parameters come from ``Config()``.

    Args:
        i: Index of the sample; also used as the matplotlib figure number.
        epoch: Integer used in the output file name.
        preds: Predicted masks, read as ``preds[i, :, :, 0]``.
        X: Spectrograms, channel 0 used as real part and channel 1 as
            imaginary part.
        Y: Target masks, read as ``Y[i, :, :, 0]``.
        figure_dir: Directory the PNG file is written to. It is not created
            here.
        mode: Suffix appended to the output file name.

    Returns:
        Always ``0``.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')

    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
    magnitude = np.abs(spectrum)
    target_mask = Y[i, :, :, 0]
    predicted_mask = preds[i, :, :, 0]

    t, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    t, ideal_denoised_signal = scipy.signal.istft(spectrum * target_mask, **istft_kwargs)
    t, denoised_signal = scipy.signal.istft(spectrum * predicted_mask, **istft_kwargs)

    def draw_image(position, image, vmax, title):
        # Spectrogram-like panel with a fixed colour scale and no x tick labels.
        plt.subplot(4, 2, position)
        plt.pcolormesh(image, vmin=0, vmax=vmax)
        plt.title(title)
        plt.gca().set_xticklabels([])

    def draw_trace(position, trace, label, ylim=None):
        # Waveform panel. Without ``ylim`` the auto-scaled range is returned
        # so that later panels can share it.
        plt.subplot(4, 2, position)
        plt.plot(t, trace, 'k', label=label, linewidth=0.5)
        if ylim is None:
            ylim = plt.gca().get_ylim()
            plt.gca().set_xticklabels([])
        else:
            plt.ylim(ylim)
            plt.xlabel("Time (s)")
        plt.legend(loc='lower left')
        plt.margins(x=0)
        return ylim

    plt.figure(i)
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1.5, 1.5])

    draw_image(1, magnitude, 2, "Noisy signal")
    signal_ylim = draw_trace(2, noisy_signal, 'Noisy signal')
    draw_image(3, target_mask, 1, "Y")
    # Raw string: "\h" is not a valid escape sequence in a normal literal.
    draw_image(4, predicted_mask, 1, r"$\hat{Y}$")
    draw_image(5, magnitude * target_mask, 2, "Ideal denoised signal")
    draw_image(6, magnitude * predicted_mask, 2, "Denoised signal")
    draw_trace(7, ideal_denoised_signal, 'Ideal denoised signal', signal_ylim)
    draw_trace(8, denoised_signal, 'Denoised signal', signal_ylim)

    plt.tight_layout()
    plt.gcf().align_labels()
    plt.savefig(
        os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)),
        bbox_inches='tight',
    )
    plt.close(i)
    return 0
def postprocessing_test(
    i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None
):
    """Write the test-time results and/or figures for sample ``i``.

    Nothing happens when both ``figure_dir`` and ``result_dir`` are ``None``.
    Otherwise five waveforms are reconstructed with ``scipy.signal.istft``
    (parameters from ``Config()``): the noisy input, the input masked by
    ``preds[i, :, :, 0]``, the input masked by ``1 - preds[i, :, :, 0]``, and
    the reference ``signal_FT[i]`` and ``noise_FT[i]``.

    Args:
        i: Index of the sample; also used as the matplotlib figure number.
        preds: Predicted masks. Channel 0 is used for the recovered signal;
            channel 1 is used only for panel (v) of the spectrogram figure.
        X: Spectrograms, channel 0 used as real part and channel 1 as
            imaginary part.
        fname: Sequence of ``bytes`` file names; ``fname[i]`` is decoded.
        figure_dir: If given, ``<name>_FT.png`` and ``<name>_wave.png`` are
            written below it (missing sub-directories are created).
        result_dir: If given, an ``.npz`` archive is written below it
            (missing sub-directories are created).
        signal_FT: Reference signal spectrograms. Despite the ``None``
            default they are indexed whenever any output is requested.
        noise_FT: Reference noise spectrograms; same remark as ``signal_FT``.
        data_dir: If given, ``itp`` and ``its`` are loaded from the raw data
            file of the same base name and zoomed insets are added to the
            waveform figure.

    Returns:
        None.
    """
    if (figure_dir is None) and (result_dir is None):
        return

    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
    signal_mask = preds[i, :, :, 0]

    # Local names deliberately avoid ``signal``, which is also the name of the
    # scipy submodule imported at the top of the file.
    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    t1, denoised_signal = scipy.signal.istft(spectrum * signal_mask, **istft_kwargs)
    # The time-domain noise estimate uses the complement of the signal mask,
    # not preds[..., 1].
    t1, denoised_noise = scipy.signal.istft(spectrum * (1 - signal_mask), **istft_kwargs)
    t1, clean_signal = scipy.signal.istft(signal_FT[i, :, :], **istft_kwargs)
    t1, clean_noise = scipy.signal.istft(noise_FT[i, :, :], **istft_kwargs)

    name = fname[i].decode()

    def ensure_parent(path):
        # Create the directory that will contain ``path`` if it is missing.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    if result_dir is not None:
        result_path = os.path.join(result_dir, name)
        ensure_parent(result_path)
        np.savez(
            result_path,
            preds=preds[i],
            X=X[i],
            signal_FT=signal_FT[i],
            noise_FT=noise_FT[i],
            noisy_signal=noisy_signal,
            denoised_signal=denoised_signal,
            denoised_noise=denoised_noise,
            signal=clean_signal,
            noise=clean_noise,
        )

    if figure_dir is None:
        return

    # Remove the literal ".npz" suffix only. str.rstrip('.npz') would strip
    # any trailing '.', 'n', 'p' or 'z' characters and corrupt some names.
    stem = name[: -len('.npz')] if name.endswith('.npz') else name
    figure_base = os.path.join(figure_dir, stem)
    ensure_parent(figure_base)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])

    if data_dir is not None:
        raw_data = np.load(os.path.join(data_dir, name.split('/')[-1]))
        itp = raw_data['itp']
        its = raw_data['its']
        # NOTE(review): 750 / 50 / 100 look like a fixed onset sample, a
        # padding in samples and a sampling rate in Hz - confirm.
        ix1 = (750 - 50) / 100
        ix2 = (750 + (its - itp) + 50) / 100
        if ix2 - ix1 > 3:
            ix2 = ix1 + 3  # cap the zoom window at 3 time units

    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.8]

    def add_panel_tag(tag):
        # Boxed roman-numeral tag in the upper-left corner of the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    # ---- Figure 1: time-frequency panels --------------------------------
    plt.figure(i)
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 2])

    magnitude = np.abs(spectrum)
    ft_panels = (
        (511, np.abs(signal_FT[i, :, :]), '(i)'),
        (512, np.abs(noise_FT[i, :, :]), '(ii)'),
        (513, magnitude, '(iii)'),
        (514, magnitude * preds[i, :, :, 0], '(iv)'),
        (515, magnitude * preds[i, :, :, 1], '(v)'),
    )
    for position, image, tag in ft_panels:
        plt.subplot(position)
        plt.pcolormesh(t_FT, f_FT, image, vmin=0, vmax=1)
        if position == 513:
            plt.ylabel("Frequency (Hz)", fontsize='large')
        if position == 515:
            plt.xlabel("Time (s)", fontsize='large')
        else:
            plt.gca().set_xticklabels([])
        add_panel_tag(tag)

    plt.savefig(figure_base + '_FT.png', bbox_inches='tight')
    plt.close(i)

    # ---- Figure 2: waveform panels --------------------------------------
    plt.figure(i)
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 2])

    # Symmetric y-limits from the traces with 100 samples trimmed at each end.
    signal_peak = np.max(np.abs(noisy_signal[100:-100]))
    signal_ylim = [-signal_peak, signal_peak]
    noise_peak = np.max(np.abs(clean_noise[100:-100]))
    noise_ylim = [-noise_peak, noise_peak]

    def draw_trace(position, trace, label, ylim, tag, ylabel=None, bottom=False):
        # One waveform panel; returns its axes so insets can be attached.
        axes = plt.subplot(position)
        plt.plot(t1, trace, 'k', linewidth=0.5, label=label)
        plt.legend(loc='lower left', fontsize='medium')
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        plt.ylim(ylim)
        if ylabel is not None:
            plt.ylabel(ylabel, fontsize='large')
        if bottom:
            plt.xlabel("Time (s)", fontsize='large')
        else:
            plt.gca().set_xticklabels([])
        add_panel_tag(tag)
        return axes

    ax3 = draw_trace(513, noisy_signal, 'Noisy signal', signal_ylim, '(iii)', ylabel="Amplitude")
    ax1 = draw_trace(511, clean_signal, 'Signal', signal_ylim, '(i)')
    draw_trace(512, clean_noise, 'Noise', noise_ylim, '(ii)')
    ax4 = draw_trace(514, denoised_signal, 'Recovered signal', signal_ylim, '(iv)')
    draw_trace(515, denoised_noise, 'Recovered noise', noise_ylim, '(v)', bottom=True)

    if data_dir is not None:
        # All three insets share the y-range of the reference signal inside
        # the zoom window.
        in_window = (t1 > ix1) & (t1 < ix2)
        y1 = -np.max(np.abs(clean_signal[in_window]))
        y2 = -y1
        inset_specs = (
            (ax1, clean_signal, 0.5),
            (ax3, noisy_signal, 0.3),
            (ax4, denoised_signal, 0.5),
        )
        for parent_axes, trace, anchor_y in inset_specs:
            axins = inset_axes(
                parent_axes,
                width=2.0,
                height=1.0,
                loc='upper right',
                bbox_to_anchor=(1, anchor_y),
                bbox_transform=parent_axes.transAxes,
            )
            axins.plot(t1, trace, 'k', linewidth=0.5)
            axins.set_xlim(ix1, ix2)
            axins.set_ylim(y1, y2)
            plt.xticks(visible=False)
            plt.yticks(visible=False)
            mark_inset(parent_axes, axins, loc1=1, loc2=3, fc="none", ec="0.5")

    plt.savefig(figure_base + '_wave.png', bbox_inches='tight')
    plt.close(i)
    return
def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None):
    """Write the prediction-time results and/or figures for sample ``i``.

    Nothing happens when both ``figure_dir`` and ``result_dir`` are ``None``.
    Otherwise the noisy input and the input masked by ``preds[i, :, :, 0]``
    (plotted as "Recovered signal") and by ``preds[i, :, :, 1]`` (plotted as
    "Recovered noise") are converted to waveforms with ``scipy.signal.istft``
    using the parameters of ``Config()``.

    Args:
        i: Index of the sample; also used as the matplotlib figure number.
        preds: Predicted masks with at least two channels on the last axis.
        X: Spectrograms, channel 0 used as real part and channel 1 as
            imaginary part.
        fname: Sequence of ``str`` file names; ``fname[i]`` names the outputs.
        figure_dir: If given, ``<name>_FT.png`` and ``<name>_wave.png`` are
            written below it (missing sub-directories are created).
        result_dir: If given, an ``.npz`` archive with ``noisy_signal``,
            ``denoised_signal``, ``denoised_noise`` and ``t`` is written
            below it (missing sub-directories are created).

    Returns:
        None.
    """
    if (result_dir is None) and (figure_dir is None):
        return

    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j

    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    t1, denoised_signal = scipy.signal.istft(spectrum * preds[i, :, :, 0], **istft_kwargs)
    t1, denoised_noise = scipy.signal.istft(spectrum * preds[i, :, :, 1], **istft_kwargs)

    name = fname[i]

    def ensure_parent(path):
        # Create the directory that will contain ``path`` if it is missing.
        # exist_ok avoids FileExistsError when another worker creates it first.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    if result_dir is not None:
        result_path = os.path.join(result_dir, name)
        ensure_parent(result_path)
        np.savez(
            result_path,
            noisy_signal=noisy_signal,
            denoised_signal=denoised_signal,
            denoised_noise=denoised_noise,
            t=t1,
        )

    if figure_dir is None:
        return

    # Remove the literal ".npz" suffix only. str.rstrip('.npz') would strip
    # any trailing '.', 'n', 'p' or 'z' characters and corrupt some names.
    stem = name[: -len('.npz')] if name.endswith('.npz') else name
    figure_base = os.path.join(figure_dir, stem)
    ensure_parent(figure_base)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    def add_panel_tag(tag):
        # Boxed roman-numeral tag in the upper-left corner of the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    def decorate(position, tag, ylabel):
        # Axis labelling shared by both figures: the middle panel carries the
        # y label, the bottom panel the x label, the others hide x tick labels.
        if position == 313:
            plt.xlabel("Time (s)", fontsize='large')
        else:
            plt.gca().set_xticklabels([])
        if position == 312:
            plt.ylabel(ylabel, fontsize='large')
        add_panel_tag(tag)

    # ---- Figure 1: time-frequency panels --------------------------------
    plt.figure(i)
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 1.2])

    magnitude = np.abs(spectrum)
    # Colour scale tied to the spread of the noisy magnitude spectrogram.
    vmax = np.std(magnitude) * 1.8
    ft_panels = (
        (311, magnitude, 'Noisy signal', '(i)'),
        (312, magnitude * preds[i, :, :, 0], 'Recovered signal', '(ii)'),
        (313, magnitude * preds[i, :, :, 1], 'Recovered noise', '(iii)'),
    )
    for position, image, label, tag in ft_panels:
        plt.subplot(position)
        plt.pcolormesh(t_FT, f_FT, image, vmin=0, vmax=vmax, shading='auto', label=label)
        decorate(position, tag, "Frequency (Hz)")

    plt.savefig(figure_base + '_FT.png', bbox_inches='tight')
    plt.close(i)

    # ---- Figure 2: waveform panels --------------------------------------
    plt.figure(i)
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 1.2])

    # Symmetric y-limits from the noisy trace with 100 samples trimmed at
    # each end; shared by all three panels.
    peak = np.max(np.abs(noisy_signal[100:-100]))
    signal_ylim = [-peak, peak]
    wave_panels = (
        (311, noisy_signal, 'Noisy signal', '(i)'),
        (312, denoised_signal, 'Recovered signal', '(ii)'),
        (313, denoised_noise, 'Recovered noise', '(iii)'),
    )
    for position, trace, label, tag in wave_panels:
        plt.subplot(position)
        plt.plot(t1, trace, 'k', label=label, linewidth=0.5)
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        plt.ylim(signal_ylim)
        plt.legend(loc='lower left', fontsize='medium')
        decorate(position, tag, "Amplitude")

    plt.savefig(figure_base + '_wave.png', bbox_inches='tight')
    plt.close(i)
    return
def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"):
    """Invert the masked spectrograms of a batch and save one archive per item.

    The complex spectrogram ``X[..., 0] + 1j * X[..., 1]`` is multiplied by
    ``mask[..., 0]`` (saved as ``data``) and/or ``mask[..., 1]`` (saved as
    ``noise``), inverted with ``scipy.signal.istft`` using the parameters of
    ``Config()``, and the result axes are permuted with ``[0, 3, 2, 1]``.

    Args:
        mask: Predicted masks; the last axis holds the two mask channels.
        X: Spectrograms; the last axis holds real and imaginary parts.
        fname: Sequence of output file names, indexed like ``X``.
        t0: Sequence of values stored under ``t0``, indexed like ``X``.
        save_signal: Whether to compute and store the masked signal.
        save_noise: Whether to compute and store the masked noise.
        result_dir: Output directory; created if missing.

    Returns:
        None.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[..., 0] + X[..., 1] * 1j

    if save_signal:
        _, denoised_signal = scipy.signal.istft(spectrum * mask[..., 0], **istft_kwargs)  # nbt, nch, nst, nt
        denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1])  # nbt, nt, nst, nch
    if save_noise:
        _, denoised_noise = scipy.signal.istft(spectrum * mask[..., 1], **istft_kwargs)
        denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1])

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    # when several workers save into the same directory.
    os.makedirs(result_dir, exist_ok=True)

    for i in range(len(X)):
        # NOTE: a disabled component is stored as None, i.e. as an object
        # entry in the archive.
        np.savez(
            os.path.join(result_dir, fname[i]),
            data=denoised_signal[i] if save_signal else None,
            noise=denoised_noise[i] if save_noise else None,
            t0=t0[i],
        )
def plot_figures(mask, X, fname, figure_dir="figures"):
    """Plot spectrogram and waveform figures for the last trace of ``X``.

    Only ``X[-1, -1, ...]`` and ``mask[-1, -1, ...]`` are used. The noisy
    spectrogram and its products with ``mask[..., 0]`` ("Recovered signal")
    and ``mask[..., 1]`` ("Recovered noise") are drawn and also inverted with
    ``scipy.signal.istft`` (parameters from ``Config()``) for the waveform
    figure.

    Args:
        mask: Predicted masks; the two leading axes are reduced with
            ``[-1, -1]`` and the last axis holds the two mask channels.
        X: Spectrograms, reduced the same way; the last axis holds real and
            imaginary parts.
        fname: Output base name (``str``); a trailing ``.npz`` is removed.
        figure_dir: Directory for ``<name>_FT.png`` and ``<name>_wave.png``;
            created if missing, including sub-directories implied by
            ``fname``.

    Returns:
        None.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')

    # plot the last channel
    mask = mask[-1, -1, ...]  # nch, nst, nf, nt, 2 => nf, nt, 2
    X = X[-1, -1, ...]
    spectrum = X[..., 0] + X[..., 1] * 1j

    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    t1, denoised_signal = scipy.signal.istft(spectrum * mask[..., 0], **istft_kwargs)
    t1, denoised_noise = scipy.signal.istft(spectrum * mask[..., 1], **istft_kwargs)

    # Remove the literal ".npz" suffix only. str.rstrip('.npz') would strip
    # any trailing '.', 'n', 'p' or 'z' characters and corrupt some names.
    stem = fname[: -len('.npz')] if fname.endswith('.npz') else fname
    figure_base = os.path.join(figure_dir, stem)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(figure_dir, exist_ok=True)
    figure_parent = os.path.dirname(figure_base)
    if figure_parent:
        os.makedirs(figure_parent, exist_ok=True)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0])
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    def add_panel_tag(tag):
        # Boxed roman-numeral tag in the upper-left corner of the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    def decorate(position, tag, ylabel):
        # Axis labelling shared by both figures: the middle panel carries the
        # y label, the bottom panel the x label, the others hide x tick labels.
        if position == 313:
            plt.xlabel("Time (s)", fontsize='large')
        else:
            plt.gca().set_xticklabels([])
        if position == 312:
            plt.ylabel(ylabel, fontsize='large')
        add_panel_tag(tag)

    # ---- Figure 1: time-frequency panels --------------------------------
    plt.figure()
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 1.2])

    magnitude = np.abs(X[:, :, 0] + X[:, :, 1] * 1j)
    # Colour scale tied to the spread of the noisy magnitude spectrogram.
    vmax = np.std(magnitude) * 1.8
    ft_panels = (
        (311, magnitude, 'Noisy signal', '(i)'),
        (312, magnitude * mask[:, :, 0], 'Recovered signal', '(ii)'),
        (313, magnitude * mask[:, :, 1], 'Recovered noise', '(iii)'),
    )
    for position, image, label, tag in ft_panels:
        plt.subplot(position)
        plt.pcolormesh(t_FT, f_FT, image, vmin=0, vmax=vmax, shading='auto', label=label)
        decorate(position, tag, "Frequency (Hz)")

    plt.savefig(figure_base + '_FT.png', bbox_inches='tight')
    plt.close()

    # ---- Figure 2: waveform panels --------------------------------------
    plt.figure()
    fig = plt.gcf()
    fig.set_size_inches(fig.get_size_inches() * [1, 1.2])

    peak = np.max(np.abs(noisy_signal))
    signal_ylim = [-peak, peak]
    # A zero peak would give identical lower and upper limits; in that case
    # matplotlib's automatic limits are kept.
    use_shared_ylim = signal_ylim[0] != signal_ylim[1]

    wave_panels = (
        (311, noisy_signal, 'Noisy signal', '(i)'),
        (312, denoised_signal, 'Recovered signal', '(ii)'),
        (313, denoised_noise, 'Recovered noise', '(iii)'),
    )
    for position, trace, label, tag in wave_panels:
        plt.subplot(position)
        plt.plot(t1, trace, 'k', label=label, linewidth=0.5)
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        if use_shared_ylim:
            plt.ylim(signal_ylim)
        plt.legend(loc='lower left', fontsize='medium')
        decorate(position, tag, "Amplitude")

    plt.savefig(figure_base + '_wave.png', bbox_inches='tight')
    plt.close()
    return
# Running this module as a script does nothing; the functions above are the
# whole content of the file.
if __name__ == "__main__":
    pass