|
import streamlit as st |
|
import torch |
|
import torchaudio |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from dataclasses import dataclass |
|
import string |
|
import IPython |
|
|
|
|
|
st.write(torch.__version__)
st.write(torchaudio.__version__)

device = 'cpu'
st.write(device)

SPEECH_FILE = 'abby_cadabby.wav'

# Load the audio once (it used to be loaded a second time inside the
# inference block). torchaudio.load returns a (channels, samples) tensor.
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# The acoustic model expects audio at bundle.sample_rate, so resample when the
# file differs. ``sample_rate`` is updated too, so later code that converts
# samples/frames to seconds stays consistent with ``waveform``.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate

with torch.inference_mode():
    waveform = waveform.to(device)
    # Each channel is run as a separate batch item; only the first channel's
    # output (emissions[0]) is used below.
    emissions, _ = model(waveform)
    # Per-frame log-probabilities over ``labels``.
    emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()

st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))
|
|
|
|
|
def plot():
    """Draw the frame-wise label scores held in the module-level ``emission``.

    ``emission`` is indexed ``[frame, label]``; it is transposed so time runs
    along the x axis. Returns a new figure with a colorbar below the image.
    """
    figure, axis = plt.subplots()
    image = axis.imshow(emission.T)
    axis.set(
        title="Frame-wise class probability",
        xlabel="Time",
        ylabel="Labels",
    )
    figure.colorbar(image, ax=axis, shrink=0.6, location="bottom")
    figure.tight_layout()
    return figure


st.pyplot(plot())
|
|
|
|
|
def remove_punctuation(input_string):
    """Normalise free text into the ``WORD|WORD|...`` form used for alignment.

    Characters in ``string.punctuation`` (which includes the apostrophe and
    ``|``) are deleted, the words are upper-cased, and they are joined with a
    single ``|`` word separator.

    Args:
        input_string: arbitrary text, e.g. a spoken-transcript sentence.

    Returns:
        The cleaned transcript, e.g. ``"Oh hi! It's me"`` -> ``"OH|HI|ITS|ME"``.
        Returns ``""`` when the input contains no words.
    """
    translator = str.maketrans('', '', string.punctuation)
    cleaned = (word.translate(translator).upper() for word in input_string.split())
    # Join with one separator. Wrapping each word as '|WORD|' produced '||'
    # between words. Words that were only punctuation become empty strings and
    # are dropped, so they cannot add extra separators.
    return '|'.join(word for word in cleaned if word)
|
|
|
|
|
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"

# get_trellis never reads the emission of tokens[0]; it is used as a start
# state. A leading "|" word separator therefore keeps the first letter from
# being aligned without being scored. The trailing "|" closes the last word,
# matching the "|WORD|WORD|" transcript format of the torchaudio
# forced-alignment tutorial.
clean_transcript = "|" + remove_punctuation(transcript) + "|"

st.write(clean_transcript)

# Label index of every transcript character. ``tokens`` is consumed by
# get_trellis and backtrack; without it the script stops with a NameError.
dictionary = {label: index for index, label in enumerate(labels)}
unsupported = sorted(set(clean_transcript) - dictionary.keys())
if unsupported:
    raise ValueError(f"Transcript contains characters the model has no label for: {unsupported}")
tokens = [dictionary[char] for char in clean_transcript]
|
|
|
|
|
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for ``tokens`` over ``emission``.

    ``trellis[t, j]`` holds the best cumulative score for having reached
    token ``j`` by frame ``t``. Each step either stays on the same token
    (scored with the blank emission) or advances from the previous token
    (scored with that token's emission).

    Args:
        emission: 2-D tensor indexed ``[frame, label]``. Assumed to hold
            log-probabilities (the script passes log_softmax output), so
            scores are added rather than multiplied.
        tokens: sequence of label indices, one per transcript character.
        blank_id: label index of the blank symbol. Defaults to 0.

    Returns:
        Float tensor of shape ``(num_frames, len(tokens))``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    # Column 0 accumulates blank scores from frame 1 on. The emission of
    # tokens[0] itself is never read, so the first token acts as a start state.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # At frame 0 only the first token is reachable.
    trellis[0, 1:] = -float("inf")
    # Mark the last num_tokens - 1 frames of column 0 as +inf. With a single
    # token, -num_tokens + 1 == 0 and the slice [0:] would overwrite the
    # whole column, so skip it in that case.
    if num_tokens > 1:
        trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        # Vectorised over tokens 1..N-1: the better of staying or advancing.
        trellis[t + 1, 1:] = torch.maximum(
            # stay on the same token, this frame emits blank
            trellis[t, 1:] + emission[t, blank_id],
            # change from the previous token, this frame emits this token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis
|
|
|
# Fill the alignment trellis for the transcript tokens, then show it.
trellis = get_trellis(emission, tokens, blank_id=0)
st.write('Trellis =', trellis)
|
|
|
|
|
def n_inf_to_p_inf():
    """Plot the module-level ``trellis`` and label its -inf and +inf regions.

    The trellis is drawn transposed with the origin at the bottom-left. The
    "- Inf" / "+ Inf" labels are placed at positions derived from the
    trellis size. Returns the new figure.
    """
    frames, n_tokens = trellis.size(0), trellis.size(1)
    figure, axis = plt.subplots()
    image = axis.imshow(trellis.T, origin="lower")
    axis.annotate("- Inf", (n_tokens / 5, n_tokens / 1.5))
    axis.annotate("+ Inf", (frames - n_tokens / 1.4, n_tokens / 3))
    figure.colorbar(image, ax=axis, shrink=0.25, location="bottom")
    figure.tight_layout()
    return figure


st.pyplot(n_inf_to_p_inf())
|
|
|
|
|
@dataclass(repr=True, eq=True)
class Point:
    """One step of the alignment path: which token is active at which frame."""

    token_index: int  # index into the transcript tokens
    time_index: int   # frame index in the emission matrix
    score: float      # frame-wise probability recorded for this step
|
|
|
def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the most likely alignment path by walking the trellis backwards.

    Starts at the last frame and last token. For each earlier frame it
    decides whether the path stayed on the current token or changed from the
    previous one. The decision compares trellis score plus frame-wise
    emission for both options.

    Args:
        trellis: output of ``get_trellis`` for the same emission and tokens.
        emission: 2-D tensor indexed ``[frame, label]``. Assumed to hold
            log-probabilities; ``.exp()`` below converts them to probabilities.
        tokens: label indices of the transcript characters.
        blank_id: label index of the blank symbol. Defaults to 0.

    Returns:
        list of ``Point`` in increasing time order, one per frame. Each holds
        the token index, the frame index and the probability of the chosen
        transition.

    Raises:
        AssertionError: if frame 0 is reached before token 0 (only when
        assertions are enabled).
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    # The last frame is recorded as staying on the last token (blank emission).
    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Reaching frame 0 before token 0 would mean no valid path.
        assert t > 0

        # Frame-wise score of staying (blank) vs changing to tokens[j].
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Add the best score of the predecessor cell for each option.
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Step back one frame; step back one token only if changing won.
        t -= 1
        if changed > stayed:
            j -= 1

        # Record the probability of the chosen transition.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # j == 0: fill the remaining early frames with token 0 (blank probability).
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    # Points were collected from the end; return them in time order.
    return path[::-1]
|
|
|
path = backtrack(trellis, emission, tokens)

# Write the header once instead of repeating it before every point.
st.write('Token index, Time index and Score:')
for p in path:
    st.write(p)
|
|
|
|
|
def plot_trellis_with_path(trellis, path):
    """Plot the trellis with the backtracked alignment path highlighted.

    Cells visited by ``path`` are set to NaN in a copy of the trellis, so the
    path stands out in the image and the caller's trellis is left unchanged.

    Args:
        trellis: 2-D tensor indexed ``[time_index, token_index]``.
        path: iterable of objects with ``time_index`` and ``token_index``.

    Returns:
        A new Matplotlib ``Figure`` containing the plot.
    """
    trellis_with_path = trellis.clone()
    for point in path:
        trellis_with_path[point.time_index, point.token_index] = float("nan")

    # Use an explicit new figure. The figure created by the previous plot in
    # this script is still pyplot's "current" one, so ``plt.imshow`` would draw
    # on top of it. Returning the figure (not the pyplot module) also matches
    # how the other plotting helpers feed st.pyplot.
    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    fig.tight_layout()
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))
|
|
|
|
|
|
|
@dataclass
class Segment:
    """A labelled half-open span of frames ``[start, end)`` with a score."""

    label: str    # character (or, after merge_words, word) covered
    start: int    # first frame, inclusive
    end: int      # one past the last frame, exclusive
    score: float  # averaged score over the span

    @property
    def length(self):
        """Number of frames covered, ``end - start``."""
        return self.end - self.start

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"
|
|
|
def merge_repeats(path, transcript=None):
    """Collapse consecutive path points that share a token into Segments.

    Args:
        path: sequence of ``Point`` objects in time order (as returned by
            ``backtrack``).
        transcript: the character sequence that ``Point.token_index`` indexes.
            Defaults to the module-level ``clean_transcript`` — the string the
            tokens were built from. The previous code indexed the global raw
            ``transcript`` (leading space, punctuation, mixed case), which
            mislabelled every segment.

    Returns:
        list of ``Segment``: one per run of equal ``token_index``. Each covers
        frames ``[first time_index, last time_index + 1)`` and has the mean
        point score as its score.
    """
    if transcript is None:
        transcript = clean_transcript

    segments = []
    i1 = 0
    while i1 < len(path):
        token = path[i1].token_index
        i2 = i1
        while i2 < len(path) and path[i2].token_index == token:
            i2 += 1
        run = path[i1:i2]
        score = sum(p.score for p in run) / len(run)
        segments.append(
            Segment(
                transcript[token],
                run[0].time_index,
                run[-1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
|
|
|
segments = merge_repeats(path)

# Write the header once instead of repeating it before every segment.
st.write('Segments:')
for seg in segments:
    st.write(seg)
|
|
|
|
|
def plot_trellis_with_segments(trellis, segments, transcript):
    """Plot character segments on the trellis and compare their scores.

    Top axis: the trellis (transposed, origin bottom-left). Each non-"|"
    segment's frames are set to NaN in row ``i``, where ``i`` is the
    segment's position in ``segments`` (presumably one segment per token, as
    produced by merge_repeats — confirm). Each segment's label and score are
    annotated.

    Bottom axis: one gray bar per non-"|" segment (centred on the span,
    width = length in frames, height = score). These are overlaid with one
    bar per alignment-path frame whose label is not "|", height = that
    frame's score.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: sequence of objects with ``label``, ``start``, ``end`` and
            ``score``; "|" segments are skipped.
        transcript: string indexed by ``Point.token_index`` to label path points.

    Returns:
        The Matplotlib ``Figure`` (two stacked axes sharing the x axis).

    NOTE(review): the per-frame bars read the module-level ``path``, not a
    parameter, so ``path`` must exist before this is called.
    """
    # Copy so the NaN markers do not modify the caller's trellis.
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

    # With origin="lower", i - 0.3 puts the label just below the segment's
    # row and i + 0.3 puts the score just above it.
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")

    ax2.set_title("Label probability with and without repetition")
    # Per-segment bars: x = centre (shifted by 0.4), height = score, width = length.
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")

    # Per-frame bars from the module-level alignment path.
    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.9, alpha=0.9)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()
    return fig
|
|
|
|
|
# Build and render the figure once. The extra bare call created a figure that
# was never shown.
st.pyplot(plot_trellis_with_segments(trellis, segments, clean_transcript))
|
|
|
|
|
|
|
def merge_words(segments, separator="|"):
    """Group consecutive character segments into word segments.

    Segments whose label equals ``separator`` split words and are dropped.
    Each word spans from its first character's ``start`` to its last
    character's ``end``. Its score is the length-weighted mean of the
    character scores.
    """

    def to_word(chars):
        text = "".join(c.label for c in chars)
        weighted = sum(c.score * c.length for c in chars)
        frames = sum(c.length for c in chars)
        return Segment(text, chars[0].start, chars[-1].end, weighted / frames)

    words = []
    current = []
    for seg in segments:
        if seg.label == separator:
            if current:
                words.append(to_word(current))
            current = []
        else:
            current.append(seg)
    if current:
        words.append(to_word(current))
    return words
|
|
|
|
|
word_segments = merge_words(segments)

# Write the header once instead of repeating it before every word.
st.write('Word Segments:')
for word in word_segments:
    st.write(word)
|
|
|
|
|
def plot_alignments(trellis, segments, word_segments, waveform=None, sample_rate=44100):
    """Plot the alignment on the trellis and the word spans on a spectrogram.

    Top panel: the trellis with each non-"|" segment's cells set to NaN,
    white boxes around every word, and character/score annotations.

    Bottom panel: the spectrogram of ``waveform`` with hatched boxes for each
    word. Word scores and character labels are placed just above the top of
    the frequency axis (``sample_rate / 2``).

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: character segments (``label``, ``start``, ``end``,
            ``score``); segment ``i`` is drawn in trellis column ``i``.
        word_segments: word segments in the same frame units.
        waveform: audio samples (array or tensor). A 1-D input is used as is.
            For a 2-D ``(channels, samples)`` input the first channel is used.
            When omitted, 1024 samples of random noise are plotted as a
            placeholder.
        sample_rate: sampling rate of ``waveform`` in Hz.

    Returns:
        The Matplotlib ``Figure``.
    """
    # Create the placeholder per call rather than evaluating np.random.randn
    # once at import time as a default argument.
    if waveform is None:
        waveform = np.random.randn(1024)
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()
    waveform = np.asarray(waveform)
    # len() of a (channels, samples) array is the channel count. That broke
    # the seconds-per-frame ratio and fed 2-D data to specgram.
    if waveform.ndim > 1:
        waveform = waveform[0]

    # Copy so the NaN markers do not modify the caller's trellis.
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    nfft = 1024
    # Seconds per trellis frame.
    ratio = len(waveform) / sample_rate / trellis.size(0)

    ax2.specgram(waveform, Fs=sample_rate, NFFT=nfft)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig
|
|
|
|
|
# Pass every argument in the order plot_alignments expects. The previous
# st.pyplot call omitted ``segments``, shifting every later argument into
# the wrong parameter. Hand it the first channel as a 1-D signal, and build
# the figure only once (the extra bare call was discarded).
st.pyplot(plot_alignments(trellis, segments, word_segments, waveform[0].cpu(), sample_rate))
|
|
|
|
|
def display_segment(i):
    """Return an IPython audio object for entry ``i`` of ``word_segments``.

    Converts the word's frame span to sample indices using the ratio of
    waveform samples to trellis frames. Prints the word, its score and its
    span in seconds (at ``bundle.sample_rate``), then returns that slice of
    the module-level ``waveform``.
    """
    samples_per_frame = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    first = int(samples_per_frame * word.start)
    last = int(samples_per_frame * word.end)
    sr = bundle.sample_rate
    print(f"{word.label} ({word.score:.2f}): {first / sr:.3f} - {last / sr:.3f} sec")
    clip = waveform[:, first:last]
    return IPython.display.Audio(clip.numpy(), rate=sr)
|
|
|
|
|
st.write('Abby Cadabby Transcript:')
# Show the transcript text itself; this previously wrote the literal word
# 'Transcript'.
st.write(transcript)
# st.audio is Streamlit's native audio player. An IPython Audio object passed
# to st.write is not guaranteed to render as a player.
st.audio(SPEECH_FILE)
|
|