Spaces:
Running
Running
File size: 2,570 Bytes
606f0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import matplotlib.pyplot as plt
import numpy as np
import torch
from ArticulatoryTextFrontend import ArticulatoryTextFrontend
# Y-axis labels of the heatmap, one per feature row, in plotting order.
# NOTE(review): presumably mirrors the feature order produced by
# ArticulatoryTextFrontend -- confirm against that module.
_FEATURE_LABELS = (
    "stressed", "very-high-tone", "high-tone", "mid-tone", "low-tone", "very-low-tone",
    "rising-tone", "falling-tone", "peaking-tone", "dipping-tone",
    "lengthened", "half-length", "shortened",
    "consonant", "vowel", "phoneme", "silence", "end of sentence",
    "questionmark", "exclamationmark", "fullstop", "word-boundary",
    "dental", "postalveolar", "velar", "palatal", "glottal", "uvular",
    "labiodental", "labial-velar", "alveolar", "bilabial", "alveolopalatal",
    "retroflex", "pharyngal", "epiglottal",
    "central", "back", "front_central", "front", "central_back",
    "mid", "close-mid", "close", "open-mid", "close_close-mid", "open-mid_open", "open",
    "rounded", "unrounded",
    "plosive", "nasal", "approximant", "trill", "flap", "fricative",
    "lateral-approximant", "implosive", "vibrant", "click", "ejective",
    "aspirated", "unvoiced", "voiced",
)


def visualize_one_hot_encoded_sequence(tensor, sentence, col_labels, cmap='BuGn'):
    """
    Visualize a 2D one-hot encoded tensor as a heatmap.

    The input is laid out as (sequence position, feature). It is clamped to
    the range [0, 1] and transposed, so that features run along the y-axis
    and sequence positions along the x-axis of the plot.

    Args:
        tensor: 2D torch tensor (or array-like convertible with
            ``torch.as_tensor``) whose second dimension has one entry per
            feature label.
        sentence: Text shown (between guillemets) as the figure title.
        col_labels: Sequence with one x-axis tick label per sequence position.
            A plain string is accepted; each character is then one label.
            ``None`` or an empty sequence disables the x-axis tick labels.
        cmap: Name of the matplotlib colormap used for the heatmap.

    Raises:
        ValueError: If the input is not 2D, if the number of features does not
            match the number of feature labels, or if ``col_labels`` is given
            but its length differs from the number of sequence positions.
    """
    data = torch.as_tensor(tensor)
    # Validate the rank *before* transposing: transpose(0, 1) on a 0-D/1-D
    # tensor would otherwise fail inside torch with an unrelated IndexError.
    if data.dim() != 2:
        raise ValueError("Input tensor must be a 2D array")

    # detach() lets tensors that still track gradients be converted to NumPy.
    heat = torch.clamp(data, min=0, max=1).detach().transpose(0, 1).cpu().numpy()

    # Check the size of the labels matches the (transposed) tensor dimensions.
    row_labels = list(_FEATURE_LABELS)
    if len(row_labels) != heat.shape[0]:
        raise ValueError("Number of row labels must match the number of rows in the tensor")
    # Explicit None/length test instead of truthiness, so array-like labels
    # (whose truth value is ambiguous) are handled as well.
    has_col_labels = col_labels is not None and len(col_labels) > 0
    if has_col_labels and len(col_labels) != heat.shape[1]:
        raise ValueError("Number of column labels must match the number of columns in the tensor")

    plt.figure(figsize=(10, 8))

    # Create the heatmap; aspect='auto' stretches cells to fill the figure.
    plt.imshow(heat, cmap=cmap, aspect='auto')

    # Add labels
    plt.yticks(np.arange(heat.shape[0]), row_labels)
    if has_col_labels:
        plt.xticks(np.arange(heat.shape[1]), list(col_labels), rotation=0)
    plt.grid(False)
    plt.xlabel('Phones')
    plt.ylabel('Features')

    # Display the heatmap
    plt.title(f"»{sentence}«")
    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    # Demo entry point: run one example sentence through the text frontend
    # and plot the resulting feature matrix, labelled with its phone string.
    demo_sentence = "Rằng: Trong Thánh trạch dồi dào."
    frontend = ArticulatoryTextFrontend(language="vie")
    visualize_one_hot_encoded_sequence(
        tensor=frontend.string_to_tensor(demo_sentence),
        sentence=demo_sentence,
        col_labels=frontend.get_phone_string(demo_sentence),
    )
|