import streamlit as st
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
import string
import IPython
# Part A: Import torch and torchaudio
device = 'cpu'
# Part B: Load the audio file
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
# The bundle's model is meant to consume audio at bundle.sample_rate. Resample
# when the file differs, and keep `sample_rate` in sync so that every later
# samples <-> seconds conversion uses the rate `waveform` is actually at.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate
# Inference mode
with torch.inference_mode():
    # Each row of `waveform` is fed to the model as one batch item
    # (assumes rows are channels, as the later `waveform[:, x0:x1]` slicing implies).
    emissions, _ = model(waveform.to(device))
    # Normalize the model outputs into per-frame log-probabilities over `labels`.
    emissions = torch.log_softmax(emissions, dim=-1)
# Keep the emissions of the first batch item only.
emission = emissions[0].cpu().detach()
# Print the labels
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))
# Part D: Frame-wise class probability plot
def plot():
    """Return a figure of the module-level ``emission`` log-probabilities.

    Labels run along the vertical axis and frames along the horizontal axis,
    with a colour bar underneath.
    """
    figure, axes = plt.subplots()
    image = axes.imshow(emission.T)
    axes.set_title("Frame-wise class probability")
    figure.colorbar(image, ax=axes, shrink=0.6, location="bottom")
    return figure
# Part E: Remove punctuation, separate words with a single '|', and convert to UPPERCASE
def remove_punctuation(input_string):
    """Normalize free text into the '|'-separated uppercase form used for alignment.

    Every character in ``string.punctuation`` is removed from each
    whitespace-separated word. Words are upper-cased and joined with exactly
    one '|' between neighbours. Words that become empty after stripping
    (e.g. a lone "-") are dropped so they cannot create doubled separators.

    >>> remove_punctuation(" Oh hi! It's me, Abby.")
    'OH|HI|ITS|ME|ABBY'
    """
    translator = str.maketrans('', '', string.punctuation)
    cleaned_words = (word.translate(translator).upper() for word in input_string.split())
    return '|'.join(word for word in cleaned_words if word)
# Test the function
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
clean_transcript = remove_punctuation(transcript)
# Part F: Populate Trellis
# Map every character of the cleaned transcript to its index in the model's labels.
dictionary = {c: i for i, c in enumerate(labels)}
unknown_chars = sorted(set(clean_transcript) - dictionary.keys())
if unknown_chars:
    raise ValueError(f"Transcript contains characters with no model label: {unknown_chars}")
tokens = [dictionary[c] for c in clean_transcript]
st.write(list(zip(clean_transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis of cumulative scores.

    Args:
        emission: 2-D tensor indexed ``[frame, label]`` of per-frame label
            scores (log-probabilities in this script).
        tokens: label index of each transcript position; must be non-empty.
        blank_id: label index whose score is added when staying on a token.

    Returns:
        Tensor of shape ``(num_frame, len(tokens))``. Entry ``[t, j]`` is the
        best cumulative score of being at token ``j`` at frame ``t``.

    Raises:
        ValueError: if ``tokens`` is empty.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    if num_tokens == 0:
        raise ValueError("tokens must contain at least one token")

    trellis = torch.zeros((num_frame, num_tokens))
    # Column 0: staying on the first token accumulates the blank score.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # Later tokens cannot be reached at frame 0.
    trellis[0, 1:] = -float("inf")
    # Set column 0 to +inf for the last num_tokens - 1 frames. Guarded because
    # with a single token, -num_tokens + 1 == 0, and the slice would then
    # overwrite the entire column.
    if num_tokens > 1:
        trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis
trellis = get_trellis(emission, tokens)
st.write('Trellis =', trellis)


# Part G: Labels and Time -Inf | +Inf
def n_inf_to_p_inf():
    """Return a figure of the module-level ``trellis`` marking its +/-inf regions.

    Annotation positions are derived from the trellis' frame and token counts.
    """
    frame_count = trellis.size(0)
    token_count = trellis.size(1)
    figure, axes = plt.subplots()
    image = axes.imshow(trellis.T, origin="lower")
    axes.annotate("- Inf", (token_count / 5, token_count / 1.5))
    # "+ Inf" sits near the right edge, pulled left by a fraction of the token count.
    axes.annotate("+ Inf", (frame_count - token_count / 1.4, token_count / 3))
    figure.colorbar(image, ax=axes, shrink=0.25, location="bottom")
    return figure
# Part H: Backtrack Trellis Emissions Tensor and Tokens
@dataclass
class Point:
    """One step of the alignment path through the trellis.

    ``@dataclass`` generates the positional constructor that ``backtrack``
    relies on (``Point(j, t, prob)``).
    """

    token_index: int  # trellis column, i.e. position in the token sequence
    time_index: int  # trellis row, i.e. emission frame index
    score: float  # frame-wise score stored by backtrack (exp of the emission value)
def backtrack(trellis, emission, tokens, blank_id=0):
    """Trace the best alignment path backwards through ``trellis``.

    Starts at the last frame and last token. At each step it decides whether
    the previous frame stayed on token ``j`` (scored with the blank emission)
    or moved from token ``j - 1`` (scored with the emission of ``tokens[j]``).
    Once token 0 is reached, the remaining earlier frames are filled in on
    token 0.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``, as built by ``get_trellis``.
        emission: per-frame label scores indexed ``[frame, label]``; each
            Point's score is ``exp`` of the chosen entry.
        tokens: label index of each transcript position.
        blank_id: label index used for the "stay" transition.

    Returns:
        List of Point, one per frame, ordered from frame 0 to the last frame.

    Raises:
        AssertionError: if frames run out before token 0 is reached (only
            checked when Python assertions are enabled).
    """
    # Start in the bottom-right corner: last frame, last token.
    t, j = trellis.size(0) - 1, trellis.size(1) - 1
    # The last frame is recorded with its blank probability.
    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        # (note: `assert` is skipped when Python runs with -O)
        assert t > 0
        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]
        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        # Update position
        t -= 1
        if changed > stayed:
            j -= 1
        # Store the path with frame-wise probability
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))
    # Now j == 0, which means, it reached the SOS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1
    # Points were collected from the last frame backwards; return them in time order.
    return path[::-1]
path = backtrack(trellis, emission, tokens)
# Write the heading once, then one line per path point (the point itself was never shown before).
st.write('Token index, Time index and Score:')
for p in path:
    st.text(f"{p.token_index}\t{p.time_index}\t{p.score:.4f}")
# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Return a figure of the trellis with the backtracked path marked.

    Cells on the path are set to NaN, so imshow renders them with the
    colormap's "bad" colour and the path stands out against the scores.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        path: Points with ``time_index`` / ``token_index`` into ``trellis``.

    Returns:
        A new matplotlib Figure.
    """
    trellis_with_path = trellis.clone()
    for p in path:
        trellis_with_path[p.time_index, p.token_index] = float("nan")
    # Draw on a dedicated figure, like the other plot helpers, instead of
    # pyplot's implicit global figure, which earlier drawing can leave content on.
    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))
# Part J: Merge Repeats | Segments
# Merge the labels
@dataclass
class Segment:
    """A run of frames assigned to one label (or, after merge_words, one word)."""

    label: str
    start: int  # first frame of the run (inclusive)
    end: int  # one past the last frame of the run (exclusive)
    score: float  # aggregated score of the run's frames

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames covered, ``end - start`` (an attribute, as merge_words uses it)."""
        return self.end - self.start
def merge_repeats(path, transcript=None):
    """Collapse consecutive path points that share a token into Segments.

    Args:
        path: Points in time order, as returned by ``backtrack``.
        transcript: sequence whose i-th character is the label of token i.
            Defaults to the module-level ``clean_transcript``. The raw
            ``transcript`` string contains spaces and punctuation, so its
            indices do not match the tokens.

    Returns:
        List of Segment, one per run of identical ``token_index``. ``end`` is
        exclusive (last frame + 1) and ``score`` is the mean point score.
    """
    if transcript is None:
        transcript = clean_transcript
    segments = []
    i1 = 0
    while i1 < len(path):
        # Advance i2 to the first point belonging to a different token.
        i2 = i1
        while i2 < len(path) and path[i2].token_index == path[i1].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
segments = merge_repeats(path)
# The original loop had no body; show each merged segment via its repr.
st.write('Segments:')
for seg in segments:
    st.text(repr(seg))
# Part K: Trellis with Segments Visualization
def plot_trellis_with_segments(trellis, segments, transcript, alignment_path=None):
    """Plot the aligned path per label plus per-label and per-frame scores.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: Segments from ``merge_repeats``; segment ``i`` is drawn in
            trellis column ``i``.
        transcript: label string indexed by ``Point.token_index``.
        alignment_path: Points from ``backtrack``. Defaults to the
            module-level ``path``.

    Returns:
        A matplotlib Figure with two stacked axes.
    """
    points = path if alignment_path is None else alignment_path

    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    # Adjust the position of the annotations to spread them out
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")

    ax2.set_title("Label probability with and without repetition")
    # Gray bars: one per merged segment, as wide as the segment.
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")

    # Narrow bars: the per-frame scores before repeats were merged.
    xs, hs = [], []
    for p in points:
        label = transcript[p.token_index]
        if label != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)
    ax2.bar(xs, hs, width=0.9, alpha=0.9)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    return fig


# Build the figure once and hand it to Streamlit.
st.pyplot(plot_trellis_with_segments(trellis, segments, clean_transcript))
# Part L: Merge words | Segments
# Merge words
def merge_words(segments, separator="|"):
    """Group character Segments into word Segments split on ``separator``.

    Args:
        segments: character-level Segments in time order (from merge_repeats).
        separator: label marking a word boundary; separator segments belong
            to no word, and consecutive separators yield no empty words.

    Returns:
        List of Segment, one per word. Each has the concatenated labels, the
        start of its first segment, the end of its last segment, and the
        frame-length-weighted mean of the segment scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(seg.label for seg in segs)
                # Weight each character by the number of frames it spans. The
                # span is computed directly as end - start so this works whether
                # Segment.length is a plain method or a property.
                spans = [seg.end - seg.start for seg in segs]
                score = sum(seg.score * span for seg, span in zip(segs, spans)) / sum(spans)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            # Skip the separator and start scanning the next word.
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
word_segments = merge_words(segments)
# Write the heading once, then show every word segment via its repr.
st.write('Word Segments:')
for word in word_segments:
    st.text(repr(word))
# Part M: Alignment Visualizations
def plot_alignments(trellis, segments, word_segments, waveform=None, sample_rate=44100):
    """Plot the label path over the trellis and word spans over a spectrogram.

    Args:
        trellis: 2-D tensor indexed ``[frame, token]``.
        segments: character Segments; segment ``i`` is drawn in trellis column ``i``.
        word_segments: word Segments whose frame spans are boxed on both axes.
        waveform: audio samples, either 1-D or ``(channel, sample)``; for 2-D
            input the first channel is used. Defaults to a fresh 1024-sample
            random-noise placeholder.
        sample_rate: samples per second of ``waveform``.

    Returns:
        A matplotlib Figure with two stacked axes.
    """
    if waveform is None:
        # Generated per call; the old default was evaluated once at definition time.
        waveform = np.random.randn(1024)
    # specgram needs 1-D samples, and len() of a (channel, sample) array
    # counts channels rather than samples.
    samples = np.asarray(waveform, dtype=np.float64)
    if samples.ndim > 1:
        samples = samples[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    NFFT = 1024
    # Seconds of audio per trellis frame.
    ratio = samples.shape[0] / sample_rate / trellis.size(0)
    ax2.specgram(samples, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)
    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    return fig
# Build the alignment figure once, with every argument in its proper position.
# The first channel is passed as 1-D samples so the spectrogram and the
# seconds-per-frame ratio are computed over the whole signal.
st.pyplot(plot_alignments(trellis, segments, word_segments, waveform[0], sample_rate))
# Part N: Display Segment
def display_segment(i):
    """Report the time span of word segment ``i`` and return its audio.

    Uses the module-level ``waveform`` (indexed ``[channel, sample]``),
    ``sample_rate``, ``trellis`` and ``word_segments``.

    Args:
        i: index into ``word_segments``.

    Returns:
        ``IPython.display.Audio`` holding the segment's samples.
    """
    # Audio samples per trellis (emission) frame.
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    # x0 / x1 are offsets into `waveform`, so convert them with the rate that
    # `waveform` is actually sampled at. Write to the page, not the server console.
    st.write(f"{word.label} ({word.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)
# Part O: Audio generation for each segment
# Heading shown above the per-segment audio section.
TRANSCRIPT_HEADING = "Abby Cadabby Transcript:"
st.write(TRANSCRIPT_HEADING)