Updated line 112 with: ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 2.4 , trellis.size(1) / 3)) | Updated lines 260-261 with: updated_clean_UPPERCASE_transcript | Added Abby_and_Prince.jpg
3eb37e3
verified
import streamlit as st | |
import torch | |
import torchaudio | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from dataclasses import dataclass | |
import string | |
import IPython | |
st.image('abby_cadabby_.jpeg')

# Part A: Report the torch / torchaudio versions and the compute device.
st.write(torch.__version__)
st.write(torchaudio.__version__)
device = 'cpu'
st.write(device)

# Part B: Load the audio file (torchaudio.load returns a (channels, samples)
# tensor and the file's sample rate).
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

# Part C: Build the wav2vec2 ASR model and its label set via torchaudio.pipelines.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# The model expects audio at bundle.sample_rate. Resample when the file differs,
# so the emissions and every later frame -> seconds conversion use the right rate.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate

# Forward pass only, so no autograd bookkeeping is needed.
with torch.inference_mode():
    # Reuse the waveform loaded in Part B instead of reading the file a second time.
    waveform = waveform.to(device)
    emissions, _ = model(waveform)
    # Turn per-frame logits into log-probabilities over `labels`.
    emissions = torch.log_softmax(emissions, dim=-1)

# Emissions for the first batch item. Each audio channel is a separate batch
# item here; this assumes the file is mono (TODO confirm).
emission = emissions[0].cpu().detach()

# Show the label set the model emits.
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))
# Part D: Plot the frame-wise class (log-)probabilities.
def plot():
    """Draw the module-level ``emission`` as an image and return the Figure.

    The tensor is transposed so time runs along the x-axis and labels along
    the y-axis; a colorbar is attached below the image.
    """
    figure, axes = plt.subplots()
    image = axes.imshow(emission.T)
    axes.set(title="Frame-wise class probability", xlabel="Time", ylabel="Labels")
    figure.colorbar(image, ax=axes, shrink=0.6, location="bottom")
    figure.tight_layout()
    return figure


st.pyplot(plot())
# Part E: Remove punctuation, upper-case each word, and separate words with "||".
def remove_punctuation(input_string):
    """Return *input_string* as an upper-case, punctuation-free, "||"-delimited transcript.

    Each whitespace-separated word has every ``string.punctuation`` character
    removed and is upper-cased. The surviving words are joined with ``"||"``,
    the same layout as the hand-written transcript used for alignment.

    Words made only of punctuation (e.g. ``"--"``) become empty and are
    dropped, so they do not produce a run of bare separators.

    >>> remove_punctuation(" Oh hi! It's me")
    'OH||HI||ITS||ME'
    """
    # Translation table that deletes every ASCII punctuation character.
    translator = str.maketrans('', '', string.punctuation)
    cleaned = (word.translate(translator).upper() for word in input_string.split())
    # Skip words that became empty after stripping punctuation.
    return '||'.join(word for word in cleaned if word)
# Run the cleaner on the spoken transcript and show the result.
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
clean_transcript = remove_punctuation(transcript)
st.write(clean_transcript)

# Part F: Map every character of the alignment transcript to its label index.
# Hand-written cleaned transcript ("||" between words) that is aligned below.
updated_clean_UPPERCASE_transcript = "OH||HI||ITS||ME||ABBY||CADABBY||DO||YOU||WANT||TO||WATCH||ME||PRACTICE||MY||MAGIC||I||AM||GOING||TO||TURN||THIS"
# Label -> index lookup; a character missing from `labels` raises KeyError.
dictionary = dict(zip(labels, range(len(labels))))
tokens = [dictionary[char] for char in updated_clean_UPPERCASE_transcript]
st.write([(char, token) for char, token in zip(updated_clean_UPPERCASE_transcript, tokens)])
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for *tokens* over *emission*.

    Args:
        emission: 2-D tensor indexed ``[frame, label]`` of per-frame scores
            (log-probabilities in this script).
        tokens: label indices of the transcript, in order.
        blank_id: label index whose score is used for "stay" transitions.

    Returns:
        A ``(num_frame, num_tokens)`` float tensor. Apart from the -inf / +inf
        boundary cells set below, ``trellis[t + 1, j]`` is the larger of two scores:
        staying on token ``j`` (adds the blank score), or advancing from
        ``j - 1`` (adds the score of ``tokens[j]``).
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    trellis = torch.zeros((num_frame, num_tokens))
    # Column 0: remaining on the first token accumulates blank scores.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # At frame 0 only the first token is reachable.
    trellis[0, 1:] = -float("inf")
    # Set the last num_tokens - 1 frames of column 0 to +inf (the "+ Inf" region
    # plotted in Part G). The start is a non-negative index because
    # `-num_tokens + 1` equals -0 == 0 for a single token and would overwrite
    # the whole column.
    trellis[max(num_frame - num_tokens + 1, 0):, 0] = float("inf")
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis
# Score every (frame, token) cell for the transcript and show the tensor.
trellis = get_trellis(emission, tokens, blank_id=0)
st.write('Trellis =', trellis)
# Part G: Trellis over labels and time, annotating the -Inf and +Inf regions.
def n_inf_to_p_inf():
    """Plot the module-level ``trellis`` with "- Inf" / "+ Inf" annotations.

    Tokens run along the y-axis (origin at the bottom) and frames along the
    x-axis. Returns the created matplotlib Figure.
    """
    num_frames, num_tokens = trellis.size(0), trellis.size(1)
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (num_tokens / 5, num_tokens / 1.5))
    # "+ Inf" sits near the right edge; a larger divisor moves it further right.
    ax.annotate("+ Inf", (num_frames - num_tokens / 2.4, num_tokens / 3))
    fig.colorbar(img, ax=ax, shrink=0.25, location="bottom")
    fig.tight_layout()
    return fig


st.pyplot(n_inf_to_p_inf())
# Part H: Backtrack the trellis to recover the alignment path.
@dataclass
class Point:
    """One (token, frame) step of the backtracked alignment path.

    ``@dataclass`` generates the positional ``__init__`` that ``backtrack``
    relies on (``Point(j, t, prob)``); without it construction raises TypeError.
    """

    token_index: int  # index into the transcript token sequence
    time_index: int  # emission frame index
    score: float  # frame-wise probability (exp of the emission score)
def backtrack(trellis, emission, tokens, blank_id=0):
    """Walk *trellis* backwards from its last cell to recover the alignment path.

    Args:
        trellis: tensor from ``get_trellis``, indexed ``[frame, token]``.
        emission: the per-frame scores used to build *trellis*.
        tokens: label indices of the transcript.
        blank_id: label index used for "stay" transitions.

    Returns:
        list of ``Point`` in increasing ``time_index`` order; each frame from 0
        to the last appears exactly once.

    Raises:
        ValueError: if frame 0 is reached before the first token, for example
            when the transcript has more tokens than the trellis has frames.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1
    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and t would then go negative and wrap to the last frames.
        if t <= 0:
            raise ValueError(
                "backtrack reached frame 0 before the first token; "
                "the transcript may be longer than the audio allows"
            )
        # Frame-wise score of staying on token j vs. changing from token j - 1.
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]
        # Context-aware score: add each predecessor's accumulated trellis value.
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        t -= 1
        if changed > stayed:
            j -= 1
            prob = p_change.exp().item()
        else:
            prob = p_stay.exp().item()
        path.append(Point(j, t, prob))
    # j == 0: the first token is reached. Pad the remaining frames with blank
    # scores so every frame is represented (used by the visualizations).
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1
    return path[::-1]
# Recover the best alignment path and list every step of it.
path = backtrack(trellis, emission, tokens, blank_id=0)
for p in path:
    st.write('Token index, Time index and Score:')
    st.write(p)
# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Plot *trellis* with the cells visited by *path* knocked out.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        path: iterable of points exposing ``time_index`` and ``token_index``.

    Returns:
        A new matplotlib Figure. It uses its own figure and axes rather than
        global ``pyplot`` state, which would draw into whatever figure is
        current (here, the Part G plot with its annotations and colorbar).
    """
    # Visited cells become NaN. imshow draws NaN with the colormap's "bad"
    # colour (transparent by default), so the path shows up as a gap.
    trellis_with_path = trellis.clone()
    for p in path:
        trellis_with_path[p.time_index, p.token_index] = float("nan")
    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    fig.tight_layout()
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))
# Part J: Merge repeated path points into per-label segments.
@dataclass
class Segment:
    """A contiguous run of frames aligned to one label (or, after merge_words, one word).

    ``@dataclass`` supplies the positional ``__init__`` used by
    ``merge_repeats`` / ``merge_words``. Because ``__repr__`` is defined
    explicitly below, dataclass keeps it instead of generating one.
    """

    label: str  # transcript character(s) covered by the segment
    start: int  # first frame index (inclusive)
    end: int  # frame index one past the last (exclusive)
    score: float  # mean frame-wise probability (length-weighted for words)

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames in ``[start, end)``; read as ``seg.length`` by merge_words."""
        return self.end - self.start
def merge_repeats(path, transcript=None):
    """Collapse consecutive path points on the same token into ``Segment`` objects.

    Args:
        path: time-ordered points with ``token_index``, ``time_index`` and
            ``score`` (as returned by ``backtrack``).
        transcript: string indexed by ``token_index`` to obtain each segment's
            label. Defaults to the module-level
            ``updated_clean_UPPERCASE_transcript`` so existing calls keep working.

    Returns:
        list of ``Segment``. Each covers ``[first time_index, last time_index + 1)``
        of its run and carries the mean score of the run's points.
    """
    if transcript is None:
        transcript = updated_clean_UPPERCASE_transcript
    segments = []
    i1 = 0
    while i1 < len(path):
        # Advance i2 past every point that shares path[i1]'s token.
        i2 = i1
        while i2 < len(path) and path[i2].token_index == path[i1].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
# Collapse the path into per-label segments and list each one.
segments = merge_repeats(path)
for seg in segments:
    st.write('Segments:')
    st.write(seg)
# Part K: Trellis with Segments Visualization
def plot_trellis_with_segments(trellis, segments, transcript, frame_path=None):
    """Plot the trellis with segments highlighted, plus per-label probability bars.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: per-label segments. Segment ``i`` is drawn in trellis column
            ``i``, which assumes segment ``i`` is token ``i`` (as with
            ``merge_repeats`` output over a full path).
        transcript: string indexed by the path points' ``token_index``.
        frame_path: backtracked path supplying the per-frame bars. Defaults to
            the module-level ``path`` so existing calls keep working.

    Returns:
        The created matplotlib Figure: trellis on top, bar chart below.
    """
    if frame_path is None:
        frame_path = path
    # Cells covered by a non-separator segment become NaN so they stand out.
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    # Label and score are offset vertically from the segment's row so they don't overlap.
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")

    ax2.set_title("Label probability with and without repetition")
    # Wide gray bars: one per segment, spanning its frames, height = segment score.
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")

    # Narrow bars: one per non-separator path frame, height = that frame's score.
    xs, hs = [], []
    for p in frame_path:
        if transcript[p.token_index] != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)
    ax2.bar(xs, hs, width=0.9, alpha=0.9)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()
    return fig
# Render the segment plot once. An extra bare call would build a second figure
# that is never shown or closed.
st.pyplot(plot_trellis_with_segments(trellis, segments, updated_clean_UPPERCASE_transcript))
# Part L: Merge per-label segments into word segments.
def merge_words(segments, separator="|"):
    """Group runs of non-separator segments into word-level ``Segment`` objects.

    Args:
        segments: time-ordered per-label segments (from ``merge_repeats``).
        separator: label that delimits words. Separator segments are dropped,
            and consecutive separators produce no empty word.

    Returns:
        list of ``Segment``, one per word. Each spans from its first segment's
        ``start`` to its last segment's ``end``, and its score is the
        frame-count-weighted mean of the segment scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(seg.label for seg in segs)
                # Frame counts are computed as end - start, so this works whether
                # `Segment.length` is a method or a property. Using the bare
                # method in arithmetic raised TypeError.
                lengths = [seg.end - seg.start for seg in segs]
                score = sum(seg.score * n for seg, n in zip(segs, lengths)) / sum(lengths)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
# Group the label segments into words (default "|" separator) and list them.
word_segments = merge_words(segments, separator="|")
for word in word_segments:
    st.write('Word Segments:')
    st.write(word)
# Part M: Alignment Visualizations
def plot_alignments(trellis, segments, word_segments, waveform=np.random.randn(1024), sample_rate=44100):
    """Plot the trellis with word spans above a spectrogram with word/label marks.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: per-label segments (from ``merge_repeats``).
        word_segments: word segments (from ``merge_words``).
        waveform: audio samples, either 1-D or ``(channels, samples)`` as
            returned by ``torchaudio.load``. Only the first channel is plotted.
            NOTE: the default is random noise created once, when the function
            is defined; callers should pass the real audio.
        sample_rate: sampling rate of *waveform* in Hz.

    Returns:
        The created matplotlib Figure.
    """
    # Work on a 1-D sample array. With a (channels, samples) tensor, len()
    # counted channels, which made `ratio` wrong, and specgram needs 1-D input.
    samples = np.asarray(waveform)
    if samples.ndim > 1:
        samples = samples[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])
    # Outline each word's frame span.
    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # Spectrogram of the audio, with word spans converted from frames to seconds.
    NFFT = 1024
    # Seconds per trellis frame; assumes the emission frames evenly cover the clip.
    ratio = len(samples) / sample_rate / trellis.size(0)
    ax2.specgram(samples, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        # y is just above the spectrogram's frequency range, which tops out at sample_rate / 2.
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)
    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig
# Render the alignment plot once, with arguments in signature order
# (trellis, segments, word_segments, waveform, sample_rate). waveform[0] is the
# first channel as a 1-D signal, which is what the spectrogram expects.
st.pyplot(plot_alignments(trellis, segments, word_segments, waveform[0], sample_rate))
# Part N: Play back a single word segment.
def display_segment(i):
    """Print the timing of word segment *i* and return an IPython audio player for it.

    Reads the module-level ``waveform`` (sliced as ``[:, start:end]``),
    ``trellis``, ``word_segments`` and ``bundle``. The timing line is written
    to stdout with ``print``.
    """
    word = word_segments[i]
    # Audio samples per trellis frame.
    samples_per_frame = waveform.size(1) / trellis.size(0)
    start_sample = int(samples_per_frame * word.start)
    end_sample = int(samples_per_frame * word.end)
    rate = bundle.sample_rate
    print(f"{word.label} ({word.score:.2f}): {start_sample / rate:.3f} - {end_sample / rate:.3f} sec")
    clip = waveform[:, start_sample:end_sample]
    return IPython.display.Audio(clip.numpy(), rate=rate)
# Part O: Show the transcript and play back the source audio.
st.write('Abby Cadabby Transcript:')
# The transcript text itself, not a placeholder label.
st.write(transcript)
# Streamlit's native audio widget. Passing the raw bytes works whether or not
# the installed Streamlit version accepts a file path directly.
with open(SPEECH_FILE, 'rb') as audio_file:
    st.audio(audio_file.read(), format='audio/wav')