TroglodyteDerivations
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
from dataclasses import dataclass
|
7 |
+
import string
|
8 |
+
import IPython
|
9 |
+
|
10 |
+
# Part A: Import torch and torchaudio
|
11 |
+
st.write(torch.__version__)
|
12 |
+
st.write(torchaudio.__version__)
|
13 |
+
device = 'cpu'
|
14 |
+
st.write(device)
|
15 |
+
|
16 |
+
# Part B: Load the audio file
|
17 |
+
SPEECH_FILE = 'abby_cadabby.wav'
|
18 |
+
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
|
19 |
+
st.write(SPEECH_FILE)
|
20 |
+
|
21 |
+
# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
|
22 |
+
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
|
23 |
+
model = bundle.get_model().to(device)
|
24 |
+
labels = bundle.get_labels()
|
25 |
+
|
26 |
+
# Inference mode
|
27 |
+
with torch.inference_mode():
|
28 |
+
# Load the audio file using torchaudio.load
|
29 |
+
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
|
30 |
+
waveform = waveform.to(device)
|
31 |
+
|
32 |
+
# Pass the waveform through the model
|
33 |
+
emissions, _ = model(waveform)
|
34 |
+
emissions = torch.log_softmax(emissions, dim=-1)
|
35 |
+
|
36 |
+
# Get the emissions for the first example
|
37 |
+
emission = emissions[0].cpu().detach()
|
38 |
+
|
39 |
+
# Print the labels
|
40 |
+
st.write('Labels are: ', labels)
|
41 |
+
st.write('Length of labels are: ', len(labels))
|
42 |
+
|
43 |
+
# Part D: Frame-wise class probability plot
|
44 |
+
def plot():
|
45 |
+
fig, ax = plt.subplots()
|
46 |
+
img = ax.imshow(emission.T)
|
47 |
+
ax.set_title("Frame-wise class probability")
|
48 |
+
ax.set_xlabel("Time")
|
49 |
+
ax.set_ylabel("Labels")
|
50 |
+
fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
|
51 |
+
fig.tight_layout()
|
52 |
+
return fig
|
53 |
+
|
54 |
+
st.pyplot(plot())
|
55 |
+
|
56 |
+
# Part E: Remove punctuation add | after each word. Also, convert into all UPPERCASE
|
57 |
+
def remove_punctuation(input_string):
|
58 |
+
# Make a translator object to remove all punctuation
|
59 |
+
translator = str.maketrans('', '', string.punctuation)
|
60 |
+
|
61 |
+
# Split the input string into words
|
62 |
+
words = input_string.split()
|
63 |
+
|
64 |
+
# Remove punctuation from each word, convert to uppercase, and join them with '|'
|
65 |
+
clean_words = ['|' + word.translate(translator).upper() + '|' for word in words]
|
66 |
+
clean_transcript = ''.join(clean_words).strip('|')
|
67 |
+
|
68 |
+
return clean_transcript
|
69 |
+
|
70 |
+
# Test the function
|
71 |
+
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
|
72 |
+
|
73 |
+
clean_transcript = remove_punctuation(transcript)
|
74 |
+
st.write(clean_transcript)
|
75 |
+
|
76 |
+
# Part F: Populate Trellis
|
77 |
+
def get_trellis(emission, tokens, blank_id=0):
|
78 |
+
num_frame = emission.size(0)
|
79 |
+
num_tokens = len(tokens)
|
80 |
+
|
81 |
+
trellis = torch.zeros((num_frame, num_tokens))
|
82 |
+
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
|
83 |
+
trellis[0, 1:] = -float("inf")
|
84 |
+
trellis[-num_tokens + 1 :, 0] = float("inf")
|
85 |
+
|
86 |
+
for t in range(num_frame - 1):
|
87 |
+
trellis[t + 1, 1:] = torch.maximum(
|
88 |
+
# Score for staying at the same token
|
89 |
+
trellis[t, 1:] + emission[t, blank_id],
|
90 |
+
# Score for changing to the next token
|
91 |
+
trellis[t, :-1] + emission[t, tokens[1:]]
|
92 |
+
)
|
93 |
+
return trellis
|
94 |
+
|
95 |
+
trellis = get_trellis(emission, tokens)
|
96 |
+
st.write('Trellis =', trellis)
|
97 |
+
|
98 |
+
# Part G: Labels and Time -Inf | +Inf
|
99 |
+
def n_inf_to_p_inf():
|
100 |
+
fig, ax = plt.subplots()
|
101 |
+
img = ax.imshow(trellis.T, origin="lower")
|
102 |
+
ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
|
103 |
+
# Shift the "+ Inf" annotation to the left by decreasing the x-coordinate value
|
104 |
+
ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 1.4, trellis.size(1) / 3))
|
105 |
+
fig.colorbar(img, ax=ax, shrink=0.25, location="bottom")
|
106 |
+
fig.tight_layout()
|
107 |
+
return fig
|
108 |
+
|
109 |
+
st.pyplot(n_inf_to_p_inf())
|
110 |
+
|
111 |
+
# Part H: Backtrack Trellis Emissions Tensor and Tokens
|
112 |
+
@dataclass
|
113 |
+
class Point:
|
114 |
+
token_index: int
|
115 |
+
time_index: int
|
116 |
+
score: float
|
117 |
+
|
118 |
+
def backtrack(trellis, emission, tokens, blank_id=0):
|
119 |
+
t, j = trellis.size(0) - 1, trellis.size(1) - 1
|
120 |
+
|
121 |
+
path = [Point(j, t, emission[t, blank_id].exp().item())]
|
122 |
+
while j > 0:
|
123 |
+
# Should not happen but just in case
|
124 |
+
assert t > 0
|
125 |
+
|
126 |
+
# 1. Figure out if the current position was stay or change
|
127 |
+
# Frame-wise score of stay vs change
|
128 |
+
p_stay = emission[t - 1, blank_id]
|
129 |
+
p_change = emission[t - 1, tokens[j]]
|
130 |
+
|
131 |
+
# Context-aware score for stay vs change
|
132 |
+
stayed = trellis[t - 1, j] + p_stay
|
133 |
+
changed = trellis[t - 1, j - 1] + p_change
|
134 |
+
|
135 |
+
# Update position
|
136 |
+
t -= 1
|
137 |
+
if changed > stayed:
|
138 |
+
j -= 1
|
139 |
+
|
140 |
+
# Store the path with frame-wise probability
|
141 |
+
prob = (p_change if changed > stayed else p_stay).exp().item()
|
142 |
+
path.append(Point(j, t, prob))
|
143 |
+
|
144 |
+
# Now j == 0, which means, it reached the SOS.
|
145 |
+
# Fill up the rest for the sake of visualization
|
146 |
+
while t > 0:
|
147 |
+
prob = emission[t - 1, blank_id].exp().item()
|
148 |
+
path.append(Point(j, t - 1, prob))
|
149 |
+
t -= 1
|
150 |
+
return path[::-1]
|
151 |
+
|
152 |
+
path = backtrack(trellis, emission, tokens)
|
153 |
+
for p in path:
|
154 |
+
st.write('Token index, Time index and Score:')
|
155 |
+
st.write(p)
|
156 |
+
|
157 |
+
# Part I: Trellis with Path Visualization
|
158 |
+
def plot_trellis_with_path(trellis, path):
|
159 |
+
# To plot trellis with path, we take advantage of 'nan' value
|
160 |
+
trellis_with_path = trellis.clone()
|
161 |
+
for _, p in enumerate(path):
|
162 |
+
trellis_with_path[p.time_index, p.token_index] = float("nan")
|
163 |
+
plt.imshow(trellis_with_path.T, origin="lower")
|
164 |
+
plt.title("The path found by backtracking")
|
165 |
+
plt.tight_layout()
|
166 |
+
return plt
|
167 |
+
|
168 |
+
st.pyplot(plot_trellis_with_path(trellis, path))
|
169 |
+
|
170 |
+
# Part J: Merge Repeats | Segments
|
171 |
+
# Merge the labels
|
172 |
+
@dataclass
|
173 |
+
class Segment:
|
174 |
+
label: str
|
175 |
+
start: int
|
176 |
+
end: int
|
177 |
+
score: float
|
178 |
+
|
179 |
+
def __repr__(self):
|
180 |
+
return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"
|
181 |
+
|
182 |
+
@property
|
183 |
+
def length(self):
|
184 |
+
return self.end - self.start
|
185 |
+
|
186 |
+
def merge_repeats(path):
|
187 |
+
i1, i2 = 0, 0
|
188 |
+
segments = []
|
189 |
+
while i1 < len(path):
|
190 |
+
while i2 < len(path) and path[i1].token_index == path[i2].token_index:
|
191 |
+
i2 += 1
|
192 |
+
score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
|
193 |
+
segments.append(
|
194 |
+
Segment(
|
195 |
+
transcript[path[i1].token_index],
|
196 |
+
path[i1].time_index,
|
197 |
+
path[i2 - 1].time_index + 1,
|
198 |
+
score,
|
199 |
+
)
|
200 |
+
)
|
201 |
+
i1 = i2
|
202 |
+
return segments
|
203 |
+
|
204 |
+
segments = merge_repeats(path)
|
205 |
+
for seg in segments:
|
206 |
+
st.write('Segments:')
|
207 |
+
st.write(seg)
|
208 |
+
|
209 |
+
# Part K: Trellis with Segments Visualization
|
210 |
+
def plot_trellis_with_segments(trellis, segments, transcript):
|
211 |
+
# To plot trellis with path, we take advantage of 'nan' value
|
212 |
+
trellis_with_path = trellis.clone()
|
213 |
+
for i, seg in enumerate(segments):
|
214 |
+
if seg.label != "|":
|
215 |
+
trellis_with_path[seg.start : seg.end, i] = float("nan")
|
216 |
+
|
217 |
+
fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
|
218 |
+
ax1.set_title("Path, label and probability for each label")
|
219 |
+
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
|
220 |
+
|
221 |
+
# Adjust the position of the annotations to spread them out
|
222 |
+
for i, seg in enumerate(segments):
|
223 |
+
if seg.label != "|":
|
224 |
+
ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
|
225 |
+
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")
|
226 |
+
|
227 |
+
ax2.set_title("Label probability with and without repetition")
|
228 |
+
xs, hs, ws = [], [], []
|
229 |
+
for seg in segments:
|
230 |
+
if seg.label != "|":
|
231 |
+
xs.append((seg.end + seg.start) / 2 + 0.4)
|
232 |
+
hs.append(seg.score)
|
233 |
+
ws.append(seg.end - seg.start)
|
234 |
+
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
|
235 |
+
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")
|
236 |
+
|
237 |
+
xs, hs = [], []
|
238 |
+
for p in path:
|
239 |
+
label = transcript[p.token_index]
|
240 |
+
if label != "|":
|
241 |
+
xs.append(p.time_index + 1)
|
242 |
+
hs.append(p.score)
|
243 |
+
|
244 |
+
ax2.bar(xs, hs, width=0.9, alpha=0.9)
|
245 |
+
ax2.axhline(0, color="black")
|
246 |
+
ax2.grid(True, axis="y")
|
247 |
+
ax2.set_ylim(-0.1, 1.1)
|
248 |
+
fig.tight_layout()
|
249 |
+
return fig
|
250 |
+
|
251 |
+
|
252 |
+
plot_trellis_with_segments(trellis, segments, clean_transcript)
|
253 |
+
st.pyplot(plot_trellis_with_segments(trellis, segments, clean_transcript))
|
254 |
+
|
255 |
+
# Part L: Merge words | Segments
|
256 |
+
# Merge words
|
257 |
+
def merge_words(segments, separator="|"):
|
258 |
+
words = []
|
259 |
+
i1, i2 = 0, 0
|
260 |
+
while i1 < len(segments):
|
261 |
+
if i2 >= len(segments) or segments[i2].label == separator:
|
262 |
+
if i1 != i2:
|
263 |
+
segs = segments[i1:i2]
|
264 |
+
word = "".join([seg.label for seg in segs])
|
265 |
+
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
|
266 |
+
words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
|
267 |
+
i1 = i2 + 1
|
268 |
+
i2 = i1
|
269 |
+
else:
|
270 |
+
i2 += 1
|
271 |
+
return words
|
272 |
+
|
273 |
+
|
274 |
+
word_segments = merge_words(segments)
|
275 |
+
for word in word_segments:
|
276 |
+
st.write('Word Segments:')
|
277 |
+
st.write(word)
|
278 |
+
|
279 |
+
# Part M: Alignment Visualizations
|
280 |
+
def plot_alignments(trellis, segments, word_segments, waveform=np.random.randn(1024), sample_rate=44100):
|
281 |
+
trellis_with_path = trellis.clone()
|
282 |
+
for i, seg in enumerate(segments):
|
283 |
+
if seg.label != "|":
|
284 |
+
trellis_with_path[seg.start : seg.end, i] = float("nan")
|
285 |
+
|
286 |
+
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))
|
287 |
+
|
288 |
+
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
|
289 |
+
ax1.set_facecolor("lightgray")
|
290 |
+
ax1.set_xticks([])
|
291 |
+
ax1.set_yticks([])
|
292 |
+
|
293 |
+
for word in word_segments:
|
294 |
+
ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")
|
295 |
+
|
296 |
+
for i, seg in enumerate(segments):
|
297 |
+
if seg.label != "|":
|
298 |
+
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
|
299 |
+
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
|
300 |
+
|
301 |
+
# The original waveform
|
302 |
+
NFFT = 1024
|
303 |
+
#ratio = waveform.size(0) / sample_rate / trellis.size(0)
|
304 |
+
#ratio = len(waveform) / sample_rate / trellis.size(0)
|
305 |
+
ratio = len(waveform) / sample_rate / trellis.size(0) #-> populates both visualizations
|
306 |
+
|
307 |
+
ax2.specgram(waveform, Fs=sample_rate, NFFT=NFFT)
|
308 |
+
for word in word_segments:
|
309 |
+
x0 = ratio * word.start
|
310 |
+
x1 = ratio * word.end
|
311 |
+
ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
|
312 |
+
ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)
|
313 |
+
|
314 |
+
for seg in segments:
|
315 |
+
if seg.label != "|":
|
316 |
+
ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
|
317 |
+
ax2.set_xlabel("time [second]")
|
318 |
+
ax2.set_yticks([])
|
319 |
+
fig.tight_layout()
|
320 |
+
return fig
|
321 |
+
|
322 |
+
|
323 |
+
plot_alignments(trellis, segments, word_segments, waveform, sample_rate)
|
324 |
+
st.pyplot(plot_alignments(trellis, word_segments, waveform, sample_rate))
|
325 |
+
|
326 |
+
# Part N: Display Segment
|
327 |
+
def display_segment(i):
|
328 |
+
ratio = waveform.size(1) / trellis.size(0)
|
329 |
+
word = word_segments[i]
|
330 |
+
x0 = int(ratio * word.start)
|
331 |
+
x1 = int(ratio * word.end)
|
332 |
+
print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
|
333 |
+
segment = waveform[:, x0:x1]
|
334 |
+
return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)
|
335 |
+
|
336 |
+
# Part O: Audio generation for each segment
|
337 |
+
st.write('Abby Cadabby Transcript:')
|
338 |
+
st.write('Transcript')
|
339 |
+
st.write(IPython.display.Audio(SPEECH_FILE))
|