Spaces:
Build error
Build error
from typing import Any | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
def validate_numpy_array(value: Any) -> np.ndarray:
    r"""
    Validates the input and makes sure it returns a numpy array (i.e on CPU)

    Args:
        value (Any): the input value; a numpy array, a list or a torch tensor

    Raises:
        TypeError: if the value is not a numpy array, a torch tensor or a list

    Returns:
        np.ndarray: numpy array of the value. A numpy array input is returned
            as the same object (no copy is made).
    """
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, list):
        return np.array(value)
    if torch.is_tensor(value):
        # detach() first: Tensor.numpy() raises a RuntimeError for tensors that
        # require grad, which is the common case for values taken from a model.
        return value.detach().cpu().numpy()
    raise TypeError("Value must be a numpy array, a torch tensor or a list")
def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None):
    """Get the most probable state means from the log_alpha_scaled.

    For every time step, the state with the largest ``log_alpha_scaled`` value
    is selected and the mean vector of that state is picked from ``means``.

    Args:
        log_alpha_scaled (torch.Tensor): Log alpha scaled values.
            - Shape: :math:`(T, N)`
        means (torch.Tensor): Means of the states.
            - Shape: :math:`(T, N, D_out)`
        decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None.
            It is called as ``decoder(x, lengths, reverse=True)`` with ``x`` of
            shape :math:`(1, D_out, T)` and ``lengths`` a one-element tensor
            holding ``T``; only the first element of its return value is used.

    Returns:
        torch.Tensor: When ``decoder`` is None, the selected means with shape
            :math:`(T, D_out)`, cast to the dtype of ``log_alpha_scaled``.
            Otherwise the first decoder output with its leading dimension
            squeezed and the result transposed.
    """
    # Index of the highest-scoring state at every time step: shape (T,).
    best_state = torch.max(log_alpha_scaled, dim=1).indices

    n_frames = means.shape[0]
    n_channels = means.shape[2]

    # torch.gather needs an index of the same rank as `means`, so broadcast the
    # per-frame state index over the feature dimension: (T,) -> (T, 1, D_out).
    gather_index = best_state.unsqueeze(1).unsqueeze(1).expand(n_frames, 1, n_channels)
    selected = torch.gather(means, 1, gather_index).squeeze(1).to(log_alpha_scaled.dtype)

    if decoder is None:
        return selected

    # The decoder receives a batch of one, laid out as (1, D_out, T).
    lengths = torch.tensor([selected.shape[0]], device=selected.device)
    decoded = decoder(selected.T.unsqueeze(0), lengths, reverse=True)[0]
    return decoded.squeeze(0).T
def plot_transition_probabilities_to_numpy(states, transition_probabilities, output_fig=False):
    """Generates transition probabilities plot for the states and the probability of transition.

    Args:
        states (torch.IntTensor): the states; used as the x tick labels
        transition_probabilities (torch.FloatTensor): the transition probabilities
        output_fig (bool): when False (the default) the figure is closed in
            pyplot before it is returned; when True it is left open.

    Returns:
        matplotlib.figure.Figure: the figure holding the plot. It is returned
            in both cases, whether or not it has been closed.
    """
    states = validate_numpy_array(states)
    transition_probabilities = validate_numpy_array(transition_probabilities)

    fig, ax = plt.subplots(figsize=(30, 3))
    ax.plot(transition_probabilities, "o")
    ax.set_title("Transition probability of state")
    ax.set_xlabel("hidden state")
    ax.set_ylabel("probability")

    # One tick per plotted probability, labelled with the matching state id.
    ax.set_xticks(list(range(len(transition_probabilities))))
    ax.set_xticklabels([int(x) for x in states], rotation=90)

    # Operate on this figure explicitly rather than on pyplot's implicit
    # "current figure", so other open figures cannot be affected.
    fig.tight_layout()
    if not output_fig:
        plt.close(fig)
    return fig