Flux9665 commited on
Commit
014393d
·
verified ·
1 Parent(s): 39048c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ArticulatoryTextFrontend import ArticulatoryTextFrontend
9
+
10
+
11
+ def visualize_one_hot_encoded_sequence(tensor, sentence, col_labels, cmap='BuGn'):
12
+ """
13
+ Visualize a 2D one-hot encoded tensor as a heatmap.
14
+ """
15
+ tensor = torch.clamp(tensor, min=0, max=1).transpose(0, 1).cpu().numpy()
16
+ if tensor.ndim != 2:
17
+ raise ValueError("Input tensor must be a 2D array")
18
+
19
+ # Check the size of labels matches the tensor dimensions
20
+ row_labels = ["stressed", "very-high-tone", "high-tone", "mid-tone", "low-tone", "very-low-tone", "rising-tone", "falling-tone", "peaking-tone", "dipping-tone", "lengthened", "half-length", "shortened", "consonant", "vowel", "phoneme", "silence", "end of sentence", "questionmark", "exclamationmark", "fullstop", "word-boundary", "dental", "postalveolar",
21
+ "velar", "palatal", "glottal", "uvular", "labiodental", "labial-velar", "alveolar", "bilabial", "alveolopalatal", "retroflex", "pharyngal", "epiglottal", "central", "back", "front_central", "front", "central_back", "mid", "close-mid", "close", "open-mid", "close_close-mid", "open-mid_open", "open", "rounded", "unrounded", "plosive",
22
+ "nasal", "approximant", "trill", "flap", "fricative", "lateral-approximant", "implosive", "vibrant", "click", "ejective", "aspirated", "unvoiced", "voiced"]
23
+
24
+ if row_labels and len(row_labels) != tensor.shape[0]:
25
+ raise ValueError("Number of row labels must match the number of rows in the tensor")
26
+ if col_labels and len(col_labels) != tensor.shape[1]:
27
+ raise ValueError("Number of column labels must match the number of columns in the tensor")
28
+
29
+ fig, ax = plt.subplots(figsize=(16, 16))
30
+
31
+ # Create the heatmap
32
+ ax.imshow(tensor, cmap=cmap, aspect='auto')
33
+
34
+ # Add labels
35
+ if row_labels:
36
+ ax.set_yticks(np.arange(tensor.shape[0]), row_labels)
37
+ if col_labels:
38
+ ax.set_xticks(np.arange(tensor.shape[1]), col_labels, rotation=0)
39
+
40
+ ax.grid(False)
41
+ ax.set_xlabel('Phones')
42
+ ax.set_ylabel('Features')
43
+
44
+ # Display the heatmap
45
+ ax.set_title(f"»{sentence}«")
46
+ return fig
47
+
48
+
49
+ def vis_wrapper(sentence, language):
50
+ tf = ArticulatoryTextFrontend(language=language.split(" ")[-1].split("(")[1].split(")")[0])
51
+ features = tf.string_to_tensor(sentence)
52
+ phones = tf.get_phone_string(sentence)
53
+
54
+ return visualize_one_hot_encoded_sequence(tensor=features, sentence=sentence, col_labels=phones)
55
+
56
+
57
+ def load_json_from_path(path):
58
+ with open(path, "r", encoding="utf8") as f:
59
+ obj = json.loads(f.read())
60
+
61
+ return obj
62
+
63
+
64
+ iso_to_name = load_json_from_path("iso_to_fullname.json")
65
+ text_selection = [f"{iso_to_name[iso_code]} ({iso_code})" for iso_code in iso_to_name]
66
+ iface = gr.Interface(fn=vis_wrapper,
67
+ inputs=[gr.Textbox(lines=2,
68
+ placeholder="write the sentence you want to visualize here...",
69
+ value="What I cannot create, I do not understand.",
70
+ label="Text input"),
71
+ gr.Dropdown(text_selection,
72
+ type="value",
73
+ value='English (eng)',
74
+ label="Select the Language of the Text (type on your keyboard to find it quickly)")],
75
+ outputs=[gr.Plot()],
76
+ allow_flagging="never",
77
+ live=False,
78
+ fill_width=True)
79
+ iface.launch()