"""Streamlit demo: CTC forced alignment of a transcript with Wav2Vec2.

The script loads a speech clip, computes frame-wise label log-probabilities
with a torchaudio Wav2Vec2 ASR bundle, builds an alignment trellis for a
known transcript, backtracks the best path and visualises the result.
"""
import string
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio

st.image('abby_cadabby_.jpeg')

# Part A: Report library versions and the compute device
st.write(torch.__version__)
st.write(torchaudio.__version__)
device = 'cpu'
st.write(device)
# NOTE: every st.pyplot() call below is given an explicit figure, so the
# 'deprecation.showPyplotGlobalUse' option does not need to be silenced.

# Part B: Load the audio file
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# The model must be fed audio at the bundle's sample rate; resample when the
# file was recorded at a different rate (no-op for matching files).
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(
        waveform, sample_rate, bundle.sample_rate
    )
    sample_rate = bundle.sample_rate

# Inference mode: no autograd bookkeeping is needed for a forward pass.
with torch.inference_mode():
    # Reuse the waveform loaded in Part B instead of reading the file again.
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

# Get the emissions for the first example: indexed as [frame, label] below.
emission = emissions[0].cpu().detach()

# Print the labels
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))


# Part D: Frame-wise class probability plot
def plot():
    """Return a figure showing the module-level ``emission`` matrix.

    The matrix is transposed so that time runs along the x axis and labels
    along the y axis.
    """
    fig, ax = plt.subplots()
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()
    return fig


st.pyplot(plot())


# Part E: Remove punctuation and add | between words. Also, convert to UPPERCASE
def remove_punctuation(input_string):
    """Normalise a transcript into ``WORD|WORD|...`` form.

    Punctuation is stripped from every whitespace-separated word, the words
    are upper-cased and joined with ``|``. Words that consist solely of
    punctuation are dropped so they cannot produce empty ``||`` entries.

    Args:
        input_string: Free-form transcript text.

    Returns:
        The cleaned transcript; an empty string if no word survives.
    """
    translator = str.maketrans('', '', string.punctuation)
    cleaned = (word.translate(translator).upper() for word in input_string.split())
    return '|'.join(word for word in cleaned if word)


# Test the function
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
clean_transcript = remove_punctuation(transcript)
st.write(clean_transcript)

# Part F: Populate Trellis
updated_clean_UPPERCASE_transcript = "OH|HI|ITS|ME|ABBY|CADABBY|DO|YOU|WANT|TO|WATCH|ME|PRACTICE|MY|MAGIC|I|AM|GOING|TO|TURN|THIS"
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in updated_clean_UPPERCASE_transcript]
st.write(list(zip(updated_clean_UPPERCASE_transcript, tokens)))


def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis for ``tokens`` over ``emission``.

    Args:
        emission: 2-D tensor indexed as ``[frame, label]``.
        tokens: Label indices of the transcript, in order.
        blank_id: Index of the blank label in ``emission``.

    Returns:
        A ``(num_frame, num_tokens)`` tensor where entry ``[t, j]`` is the
        best accumulated score of being at token ``j`` at frame ``t``.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    # First column: cumulative score of emitting only blanks so far.
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # At the first frame only the first token is reachable.
    trellis[0, 1:] = -float("inf")
    # Mark the last ``num_tokens - 1`` frames of the first column with +inf.
    # With a single token the slice start would be 0 and overwrite the whole
    # column, so skip it in that case.
    if num_tokens > 1:
        trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis


trellis = get_trellis(emission, tokens)
st.write('Trellis =', trellis)


# Part G: Labels and Time -Inf | +Inf
def n_inf_to_p_inf():
    """Return a figure of the module-level ``trellis`` with its -Inf/+Inf regions labelled."""
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    # Shift the "+ Inf" annotation to the right by increasing the denominator
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 2.4, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.25, location="bottom")
    fig.tight_layout()
    return fig


st.pyplot(n_inf_to_p_inf())


# Part H: Backtrack Trellis Emissions Tensor and Tokens
@dataclass
class Point:
    """One step of the alignment path."""

    token_index: int  # column of the trellis (position in the transcript)
    time_index: int  # row of the trellis (frame)
    score: float  # frame-wise probability (exp of the emission entry used)


def backtrack(trellis, emission, tokens, blank_id=0):
    """Walk the trellis backwards from its last cell to recover the path.

    Args:
        trellis: Tensor produced by ``get_trellis``.
        emission: 2-D tensor indexed as ``[frame, label]``.
        tokens: Label indices of the transcript, in order.
        blank_id: Index of the blank label in ``emission``.

    Returns:
        A list of ``Point`` ordered from the first frame to the last.

    Raises:
        AssertionError: If the first frame is reached before the first token.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SOS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


path = backtrack(trellis, emission, tokens)
for p in path:
    st.write('Token index, Time index and Score:')
    st.write(p)


# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Return a figure of ``trellis`` with the cells visited by ``path`` blanked out."""
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for point in path:
        trellis_with_path[point.time_index, point.token_index] = float("nan")
    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.tight_layout()
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))


# Part J: Merge Repeats | Segments
@dataclass
class Segment:
    """A labelled span of frames ``[start, end)`` with an average score."""

    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames covered by the segment."""
        return self.end - self.start


def merge_repeats(path, transcript=None):
    """Collapse consecutive path points that share a token into segments.

    Args:
        path: List of ``Point`` as returned by ``backtrack``.
        transcript: String indexed by ``Point.token_index`` to obtain each
            segment's label. Defaults to the module-level
            ``updated_clean_UPPERCASE_transcript``.

    Returns:
        A list of ``Segment``, one per run of equal ``token_index``; the
        score of a segment is the mean of its points' scores.
    """
    if transcript is None:
        transcript = updated_clean_UPPERCASE_transcript

    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # Advance i2 to the end of the run that starts at i1.
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


segments = merge_repeats(path)
for seg in segments:
    st.write('Segments:')
    st.write(seg)

# Part K: Trellis with Segments Visualization
# The alignment figure needs the word segments computed in Part L, so it is
# defined and rendered once in Part M.


# Part L: Merge words | Segments
def merge_words(segments, separator="|"):
    """Group character segments into word segments.

    Args:
        segments: List of ``Segment`` as returned by ``merge_repeats``.
        separator: Label that delimits words.

    Returns:
        A list of ``Segment`` whose label is the concatenated word and whose
        score is the length-weighted mean of the member segments' scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            # Skip the separator and start the next word after it.
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(segments)
for word in word_segments:
    st.write('Word Segments:')
    st.write(word)


# Part M: Alignment Visualizations
def plot_alignments(trellis, segments, word_segments, waveform=None, sample_rate=44100):
    """Return a two-panel figure: trellis with segments, and a spectrogram.

    Args:
        trellis: Tensor produced by ``get_trellis``.
        segments: Character segments from ``merge_repeats``.
        word_segments: Word segments from ``merge_words``.
        waveform: Audio samples as a 1-D array-like, or a 2-D
            ``(channel, time)`` array/tensor of which the first channel is
            used. When omitted, 1024 samples of random noise are used as a
            placeholder.
        sample_rate: Sample rate of ``waveform`` in Hz.
    """
    if waveform is None:
        waveform = np.random.randn(1024)
    samples = np.asarray(waveform)
    if samples.ndim > 1:
        # specgram and the duration computation below need a 1-D signal.
        samples = samples[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start:seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(22, 22))

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    NFFT = 1024
    # Seconds of audio per trellis frame; converts frame indices to time.
    ratio = len(samples) / sample_rate / trellis.size(0)
    ax2.specgram(samples, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig


# Plot against the real audio so the spectrogram and the time axis match the
# aligned segments.
st.pyplot(
    plot_alignments(
        trellis,
        segments,
        word_segments,
        waveform=waveform[0].cpu().numpy(),
        sample_rate=sample_rate,
    )
)

# Part N: Audio generation
st.write('Abby Cadabby Transcript:')

# Display the audio in the Streamlit app
st.audio(SPEECH_FILE, format="audio/wav")
st.image('Abby_and_Prince.jpg')