|
import inspect |
|
import typing |
|
from functools import wraps |
|
|
|
from . import util |
|
|
|
|
|
def format_figure(func): |
|
"""Decorator for formatting figures produced by the code below. |
|
See :py:func:`audiotools.core.util.format_figure` for more. |
|
|
|
Parameters |
|
---------- |
|
func : Callable |
|
Plotting function that is decorated by this function. |
|
|
|
""" |
|
|
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
f_keys = inspect.signature(util.format_figure).parameters.keys() |
|
f_kwargs = {} |
|
for k, v in list(kwargs.items()): |
|
if k in f_keys: |
|
kwargs.pop(k) |
|
f_kwargs[k] = v |
|
func(*args, **kwargs) |
|
util.format_figure(**f_kwargs) |
|
|
|
return wrapper |
|
|
|
|
|
class DisplayMixin: |
|
@format_figure |
|
def specshow( |
|
self, |
|
preemphasis: bool = False, |
|
x_axis: str = "time", |
|
y_axis: str = "linear", |
|
n_mels: int = 128, |
|
**kwargs, |
|
): |
|
"""Displays a spectrogram, using ``librosa.display.specshow``. |
|
|
|
Parameters |
|
---------- |
|
preemphasis : bool, optional |
|
Whether or not to apply preemphasis, which makes high |
|
frequency detail easier to see, by default False |
|
x_axis : str, optional |
|
How to label the x axis, by default "time" |
|
y_axis : str, optional |
|
How to label the y axis, by default "linear" |
|
n_mels : int, optional |
|
If displaying a mel spectrogram with ``y_axis = "mel"``, |
|
this controls the number of mels, by default 128. |
|
kwargs : dict, optional |
|
Keyword arguments to :py:func:`audiotools.core.util.format_figure`. |
|
""" |
|
import librosa |
|
import librosa.display |
|
|
|
|
|
|
|
signal = self.clone() |
|
signal.stft_data = None |
|
|
|
if preemphasis: |
|
signal.preemphasis() |
|
|
|
ref = signal.magnitude.max() |
|
log_mag = signal.log_magnitude(ref_value=ref) |
|
|
|
if y_axis == "mel": |
|
log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10() |
|
log_mag -= log_mag.max() |
|
|
|
librosa.display.specshow( |
|
log_mag.numpy()[0].mean(axis=0), |
|
x_axis=x_axis, |
|
y_axis=y_axis, |
|
sr=signal.sample_rate, |
|
**kwargs, |
|
) |
|
|
|
@format_figure |
|
def waveplot(self, x_axis: str = "time", **kwargs): |
|
"""Displays a waveform plot, using ``librosa.display.waveshow``. |
|
|
|
Parameters |
|
---------- |
|
x_axis : str, optional |
|
How to label the x axis, by default "time" |
|
kwargs : dict, optional |
|
Keyword arguments to :py:func:`audiotools.core.util.format_figure`. |
|
""" |
|
import librosa |
|
import librosa.display |
|
|
|
audio_data = self.audio_data[0].mean(dim=0) |
|
audio_data = audio_data.cpu().numpy() |
|
|
|
plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot" |
|
wave_plot_fn = getattr(librosa.display, plot_fn) |
|
wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) |
|
|
|
@format_figure |
|
def wavespec(self, x_axis: str = "time", **kwargs): |
|
"""Displays a waveform plot, using ``librosa.display.waveshow``. |
|
|
|
Parameters |
|
---------- |
|
x_axis : str, optional |
|
How to label the x axis, by default "time" |
|
kwargs : dict, optional |
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. |
|
""" |
|
import matplotlib.pyplot as plt |
|
from matplotlib.gridspec import GridSpec |
|
|
|
gs = GridSpec(6, 1) |
|
plt.subplot(gs[0, :]) |
|
self.waveplot(x_axis=x_axis) |
|
plt.subplot(gs[1:, :]) |
|
self.specshow(x_axis=x_axis, **kwargs) |
|
|
|
def write_audio_to_tb( |
|
self, |
|
tag: str, |
|
writer, |
|
step: int = None, |
|
plot_fn: typing.Union[typing.Callable, str] = "specshow", |
|
**kwargs, |
|
): |
|
"""Writes a signal and its spectrogram to Tensorboard. Will show up |
|
under the Audio and Images tab in Tensorboard. |
|
|
|
Parameters |
|
---------- |
|
tag : str |
|
Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be |
|
written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``). |
|
writer : SummaryWriter |
|
A SummaryWriter object from PyTorch library. |
|
step : int, optional |
|
The step to write the signal to, by default None |
|
plot_fn : typing.Union[typing.Callable, str], optional |
|
How to create the image. Set to ``None`` to avoid plotting, by default "specshow" |
|
kwargs : dict, optional |
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or |
|
whatever ``plot_fn`` is set to. |
|
""" |
|
import matplotlib.pyplot as plt |
|
|
|
audio_data = self.audio_data[0, 0].detach().cpu() |
|
sample_rate = self.sample_rate |
|
writer.add_audio(tag, audio_data, step, sample_rate) |
|
|
|
if plot_fn is not None: |
|
if isinstance(plot_fn, str): |
|
plot_fn = getattr(self, plot_fn) |
|
fig = plt.figure() |
|
plt.clf() |
|
plot_fn(**kwargs) |
|
writer.add_figure(tag.replace("wav", "png"), fig, step) |
|
|
|
def save_image( |
|
self, |
|
image_path: str, |
|
plot_fn: typing.Union[typing.Callable, str] = "specshow", |
|
**kwargs, |
|
): |
|
"""Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to |
|
a specified file. |
|
|
|
Parameters |
|
---------- |
|
image_path : str |
|
Where to save the file to. |
|
plot_fn : typing.Union[typing.Callable, str], optional |
|
How to create the image. Set to ``None`` to avoid plotting, by default "specshow" |
|
kwargs : dict, optional |
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or |
|
whatever ``plot_fn`` is set to. |
|
""" |
|
import matplotlib.pyplot as plt |
|
|
|
if isinstance(plot_fn, str): |
|
plot_fn = getattr(self, plot_fn) |
|
|
|
plt.clf() |
|
plot_fn(**kwargs) |
|
plt.savefig(image_path, bbox_inches="tight", pad_inches=0) |
|
plt.close() |
|
|