TroglodyteDerivations's picture
Updated with: [removal] of display_segment
24cbc7d verified
import streamlit as st
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
import string
st.image('abby_cadabby_.jpeg')

# Part A: Import torch and torchaudio
st.write(torch.__version__)
st.write(torchaudio.__version__)
device = 'cpu'
st.write(device)
# No st.set_option('deprecation.showPyplotGlobalUse', ...) here: recent
# Streamlit releases removed that config option and st.set_option raises on
# unknown keys. Pass an explicit Figure to st.pyplot instead.

# Part B: Load the audio file (once; the same tensor is used for inference
# and for the spectrogram later on).
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# The acoustic model is built for audio at bundle.sample_rate. Resample when
# the file uses a different rate, and keep `sample_rate` in sync with
# `waveform` so later seconds-per-frame conversions stay correct.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate
waveform = waveform.to(device)

# Inference mode: no autograd bookkeeping is needed for alignment.
with torch.inference_mode():
    # Pass the waveform through the model. With torchaudio.load's default
    # channels-first layout, each channel is one item of the batch.
    emissions, _ = model(waveform)
    # Per-frame log-probabilities over the label set.
    emissions = torch.log_softmax(emissions, dim=-1)

# Get the emissions for the first example (first channel) as a CPU tensor.
emission = emissions[0].cpu().detach()

# Print the labels
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))
# Part D: Frame-wise class probability plot
def plot():
    """Build a heat map of the module-level `emission` tensor.

    `emission` holds per-frame log-probabilities. It is transposed so that
    frames run along the x axis and labels along the y axis.

    Returns:
        The Matplotlib Figure, ready for st.pyplot.
    """
    figure, axes = plt.subplots()
    heat = axes.imshow(emission.T)
    axes.set_title("Frame-wise class probability")
    axes.set_xlabel("Time")
    axes.set_ylabel("Labels")
    figure.colorbar(heat, ax=axes, shrink=0.6, location="bottom")
    figure.tight_layout()
    return figure


st.pyplot(plot())
# Part E: Remove punctuation, join the words with '|', and convert to UPPERCASE
def remove_punctuation(input_string, separator='|'):
    """Normalise a transcript into separator-joined, upper-case words.

    Deletes every character in ``string.punctuation`` (ASCII punctuation only;
    e.g. curly quotes are kept) from each whitespace-separated word,
    upper-cases it, and joins the words with ``separator``. Words consisting
    solely of punctuation (such as a lone "-") are dropped, so they cannot
    leave an empty word between two separators.

    Args:
        input_string: Raw transcript text.
        separator: Word delimiter placed between words; defaults to '|'.

    Returns:
        The cleaned transcript, e.g. "Oh hi!" -> "OH|HI". Returns "" when no
        word survives cleaning.
    """
    # Translation table that deletes every ASCII punctuation character.
    translator = str.maketrans('', '', string.punctuation)
    cleaned_words = (word.translate(translator).upper() for word in input_string.split())
    return separator.join(word for word in cleaned_words if word)
# Test the function
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
clean_transcript = remove_punctuation(transcript)
st.write(clean_transcript)
# Part F: Populate Trellis
# Derive the alignment target from the cleaned transcript rather than keeping
# a hand-typed copy that can silently drift out of sync with `transcript`.
updated_clean_UPPERCASE_transcript = clean_transcript
# Map every label character to its index in the model's output vocabulary.
dictionary = {c: i for i, c in enumerate(labels)}
# Raises KeyError for any transcript character that is not in `labels`.
tokens = [dictionary[c] for c in updated_clean_UPPERCASE_transcript]
st.write(list(zip(updated_clean_UPPERCASE_transcript, tokens)))
def get_trellis(emission, tokens, blank_id=0):
    """Build the forced-alignment trellis for `tokens` over `emission`.

    Args:
        emission: 2-D tensor indexed as [frame, label] holding per-frame
            log-probabilities (this app passes log_softmax output).
        tokens: Label indices of the target transcript, in order.
        blank_id: Index of the blank label on `emission`'s label axis.

    Returns:
        A (num_frame, num_tokens) tensor. Entry [t, j] is the best accumulated
        score for being at token j at frame t, where each step either stays
        (adds the blank score) or advances from token j - 1.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    # Column 0: remaining on the first token by emitting blanks.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # No later token can already be reached at frame 0.
    trellis[0, 1:] = -float("inf")
    # Forbid staying on token 0 in the last frames, so every later token still
    # fits. Guarded: with one token the slice would be `[-0:]`, i.e. the whole
    # column, turning every entry into +inf.
    if num_tokens > 1:
        trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis
trellis = get_trellis(emission, tokens)
st.write('Trellis =', trellis)


# Part G: Labels and Time -Inf | +Inf
def n_inf_to_p_inf():
    """Plot the module-level `trellis` with "- Inf" / "+ Inf" text labels.

    The trellis is transposed and drawn with origin="lower", so frames run
    along x and tokens along y. The two annotations are positioned relative
    to the trellis size, intended to point at its -inf and +inf regions.

    Returns:
        The Matplotlib Figure, ready for st.pyplot.
    """
    frame_count = trellis.size(0)
    token_count = trellis.size(1)
    figure, axes = plt.subplots()
    heat = axes.imshow(trellis.T, origin="lower")
    axes.annotate("- Inf", (token_count / 5, token_count / 1.5))
    # A larger denominator moves "+ Inf" further right, towards the last frame.
    plus_x = frame_count - token_count / 2.4
    axes.annotate("+ Inf", (plus_x, token_count / 3))
    figure.colorbar(heat, ax=axes, shrink=0.25, location="bottom")
    figure.tight_layout()
    return figure


st.pyplot(n_inf_to_p_inf())
# Part H: Backtrack Trellis Emissions Tensor and Tokens
@dataclass
class Point:
    """A single (token, frame) step on the backtracked alignment path."""

    # Trellis column: position within `tokens`.
    token_index: int
    # Trellis row: emission frame number.
    time_index: int
    # Frame-wise probability of this step (exp of the log-probability).
    score: float
def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the best alignment path through `trellis`.

    Starts at the last frame and last token, then walks backwards. At each
    step it picks whichever predecessor scores higher: staying on the same
    token (via the blank label) or coming from the previous token.

    Args:
        trellis: Output of get_trellis for the same emission and tokens.
        emission: [frame, label] log-probabilities.
        tokens: Label indices of the target transcript.
        blank_id: Index of the blank label.

    Returns:
        A list of Point with one entry per frame, ordered by increasing
        time_index.

    Raises:
        ValueError: If the first frame is reached before every token has
            been placed.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Explicit check instead of `assert`: `python -O` strips asserts, and
        # t would then go negative and index frames from the end.
        if t <= 0:
            raise ValueError("backtrack reached the first frame before aligning every token")

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        moved = bool(changed > stayed)

        # Update position
        t -= 1
        if moved:
            j -= 1

        # Store the path with frame-wise probability
        prob = (p_change if moved else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0: the path has reached the first token.
    # Fill up the remaining frames with blanks for the sake of visualization.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]
path = backtrack(trellis, emission, tokens)
# Write the column header once, not before every single point.
st.write('Token index, Time index and Score:')
for p in path:
    st.write(p)
# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Plot `trellis` with the cells visited by `path` masked out.

    Args:
        trellis: [frame, token] score tensor. It is cloned, not modified.
        path: Iterable of points exposing `time_index` and `token_index`.

    Returns:
        The Matplotlib Figure, ready for st.pyplot.
    """
    # To plot the trellis with the path, we take advantage of 'nan' values:
    # imshow draws NaN cells as masked, so the path shows up as a gap.
    trellis_with_path = trellis.clone()
    for p in path:
        trellis_with_path[p.time_index, p.token_index] = float("nan")

    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))
# Part J: Merge Repeats | Segments
# Merge the labels
@dataclass
class Segment:
    """A labelled half-open frame span [start, end) with an aggregate score."""

    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        template = "{}\t({:4.2f}) : [{:5d}, {:5d})"
        return template.format(self.label, self.score, self.start, self.end)

    @property
    def length(self):
        """Number of frames covered; `end` is exclusive."""
        span = self.end - self.start
        return span
def merge_repeats(path, transcript=None):
    """Collapse consecutive path points on the same token into Segments.

    Args:
        path: Points ordered by time, as returned by backtrack.
        transcript: Character sequence indexed by each point's token_index.
            Defaults to the module-level `updated_clean_UPPERCASE_transcript`,
            so existing one-argument calls behave exactly as before.

    Returns:
        One Segment per run of equal token_index. Each covers frames
        [first time_index, last time_index + 1) and is scored with the
        mean of the run's point scores.
    """
    if transcript is None:
        transcript = updated_clean_UPPERCASE_transcript

    segments = []
    i1 = 0
    while i1 < len(path):
        # Advance i2 to the first point that belongs to the next run.
        i2 = i1
        while i2 < len(path) and path[i2].token_index == path[i1].token_index:
            i2 += 1
        run = path[i1:i2]
        score = sum(p.score for p in run) / len(run)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
segments = merge_repeats(path)
# Write the header once, then one line per segment.
st.write('Segments:')
for seg in segments:
    st.write(seg)
# Part K: Trellis with Segments Visualization
# NOTE(review): a second `plot_alignments` is defined in Part M before this
# one is ever called, so at runtime only the Part M version is used.
def plot_alignments(trellis, segments, word_segments, waveform=np.random.randn(1024), sample_rate=44100):
    """Plot the aligned path over the trellis (top) and a spectrogram (bottom).

    Args:
        trellis: [frame, token] score tensor from get_trellis (not modified).
        segments: Character Segments. Segment i is drawn in trellis column i;
            '|' separators are skipped.
        word_segments: Word Segments, drawn as spans on both axes.
        waveform: Audio samples. A 1-D array is used as-is. For a 2-D
            (channels, time) array or CPU tensor, the first channel is plotted.
            Defaults to 1024 samples of random noise, drawn once at
            definition time.
        sample_rate: Sampling rate of `waveform` in Hz.

    Returns:
        The Matplotlib Figure, ready for st.pyplot.
    """
    # Accept tensors and channel-first arrays: len() of a 2-D array is its
    # channel count, which would make the seconds-per-frame ratio meaningless.
    waveform = np.asarray(waveform)
    if waveform.ndim > 1:
        waveform = waveform[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    NFFT = 1024  # FFT window length used by specgram
    # Seconds per trellis frame: maps frame indices onto the time axis.
    ratio = len(waveform) / sample_rate / trellis.size(0)
    # Add a small offset to the waveform to avoid log of zero or negative numbers
    waveform = waveform + 1e-10
    ax2.specgram(waveform, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        # Placed just above the spectrogram's top edge (Nyquist = Fs / 2);
        # annotation_clip=False lets it render outside the axes.
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig
# Nothing is rendered for Part K here. A bare st.pyplot() would re-display
# whatever Matplotlib's current global figure is (the most recently created
# one), and word_segments does not exist yet. The alignment figure is rendered
# in Part M.
# Part L: Merge words | Segments
# Merge words
def merge_words(segments, separator="|"):
    """Group character Segments into word Segments delimited by `separator`.

    Each run of consecutive non-separator segments becomes one word Segment:
    - label: the run's labels concatenated;
    - span: from the run's first start to its last end;
    - score: the length-weighted mean of the run's scores.

    Empty runs (leading, trailing, or doubled separators) are skipped.
    """
    words = []
    word_start = 0
    count = len(segments)
    for cursor in range(count + 1):
        # Keep scanning until a separator or the end of the list closes a word.
        if cursor < count and segments[cursor].label != separator:
            continue
        group = segments[word_start:cursor]
        if group:
            text = "".join([seg.label for seg in group])
            weighted = sum(seg.score * seg.length for seg in group) / sum(seg.length for seg in group)
            words.append(Segment(text, group[0].start, group[-1].end, weighted))
        word_start = cursor + 1
    return words
word_segments = merge_words(segments)
# Write the header once, then one line per word.
st.write('Word Segments:')
for word in word_segments:
    st.write(word)
# Part M: Alignment Visualizations
def plot_alignments(trellis, segments, word_segments, waveform=np.random.randn(1024), sample_rate=44100):
    """Plot the aligned path over the trellis (top) and a spectrogram (bottom).

    Args:
        trellis: [frame, token] score tensor from get_trellis (not modified).
        segments: Character Segments. Segment i is drawn in trellis column i;
            '|' separators are skipped.
        word_segments: Word Segments, drawn as spans on both axes.
        waveform: Audio samples. A 1-D array is used as-is. For a 2-D
            (channels, time) array or CPU tensor, the first channel is plotted.
            Defaults to 1024 samples of random noise, drawn once at
            definition time.
        sample_rate: Sampling rate of `waveform` in Hz.

    Returns:
        The Matplotlib Figure, ready for st.pyplot.
    """
    # Accept tensors and channel-first arrays: len() of a 2-D array is its
    # channel count, which would make the seconds-per-frame ratio meaningless.
    waveform = np.asarray(waveform)
    if waveform.ndim > 1:
        waveform = waveform[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(22, 22))
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    NFFT = 1024  # FFT window length used by specgram
    # Seconds per trellis frame: maps frame indices onto the time axis.
    ratio = len(waveform) / sample_rate / trellis.size(0)
    ax2.specgram(waveform, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        # Placed just above the spectrogram's top edge (Nyquist = Fs / 2);
        # annotation_clip=False lets it render outside the axes.
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig
# Plot the real audio: the first channel of the loaded `waveform`, at the
# sample rate that matches it, so the word spans line up with the speech
# in the spectrogram.
st.pyplot(plot_alignments(trellis, segments, word_segments, waveform[0].cpu().numpy(), sample_rate))
# Part N: Audio playback
st.write('Abby Cadabby Transcript:')
# Display the audio in the Streamlit app
st.audio(SPEECH_FILE, format="audio/wav")
st.image('Abby_and_Prince.jpg')