"""Streamlit demo: CTC forced alignment of a transcript with wav2vec 2.0.

Loads ``abby_cadabby.wav``, runs torchaudio's ``WAV2VEC2_ASR_BASE_960H``
bundle to obtain frame-wise label log-probabilities, aligns a known
transcript against them with a trellis + backtracking, then visualizes and
plays back the resulting character and word segments.
"""
import string
from dataclasses import dataclass

import IPython
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio


def _show(fig):
    """Render a Matplotlib figure in Streamlit, then close it.

    Streamlit re-executes the whole script on every interaction, so figures
    that are never closed accumulate in pyplot's global figure registry.
    """
    st.pyplot(fig)
    plt.close(fig)


# Part A: Import torch and torchaudio
st.write(torch.__version__)
st.write(torchaudio.__version__)
device = 'cpu'
st.write(device)

# Part B: Load the audio file
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# The acoustic model expects audio at bundle.sample_rate, and every later
# sample <-> frame conversion in this script divides by bundle.sample_rate,
# so bring the waveform to that rate once, up front.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate

# Inference mode: no autograd bookkeeping. `waveform` itself stays on the CPU
# so it can later be converted to NumPy for plotting / playback.
with torch.inference_mode():
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

# Log-probabilities for the first example, indexed [frame, label].
emission = emissions[0].cpu().detach()

# Print the labels
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))


# Part D: Frame-wise class probability plot
def plot():
    """Return a figure showing the frame-wise label log-probabilities."""
    fig, ax = plt.subplots()
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()
    return fig


_show(plot())


# Part E: Remove punctuation, separate words with a single "|", UPPERCASE
def remove_punctuation(input_string):
    """Normalize free text into the model's transcript format.

    Punctuation is stripped, every word is upper-cased, and words are joined
    with a single ``|`` (the word-boundary label of the wav2vec2 vocabulary).
    Words that consist only of punctuation are dropped so they cannot
    produce empty words / doubled separators.

    Args:
        input_string: arbitrary text.

    Returns:
        e.g. ``"OH|HI|ITS|ME"`` for ``" Oh hi! It's me"``; ``""`` for
        empty or all-punctuation input.
    """
    translator = str.maketrans('', '', string.punctuation)
    cleaned = (word.translate(translator).upper() for word in input_string.split())
    return '|'.join(word for word in cleaned if word)


# Test the function
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"
clean_transcript = remove_punctuation(transcript)
st.write(clean_transcript)

# Part F: Populate Trellis
dictionary = {c: i for i, c in enumerate(labels)}

# Fail with a clear message rather than a bare KeyError / IndexError further
# down if the normalized transcript cannot be tokenized.
if not clean_transcript:
    raise ValueError("The cleaned transcript is empty; nothing to align.")
unknown_chars = sorted(set(clean_transcript) - dictionary.keys())
if unknown_chars:
    raise ValueError(
        f"Transcript contains characters outside the model vocabulary: {unknown_chars}"
    )

tokens = [dictionary[c] for c in clean_transcript]
st.write(list(zip(clean_transcript, tokens)))


def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis of cumulative log-probabilities.

    ``trellis[t, j]`` is the best score of having emitted ``tokens[:j + 1]``
    by frame ``t``. Each step either stays on the current token (consuming a
    blank frame) or advances to the next token.

    Args:
        emission: frame-wise label log-probabilities, indexed [frame, label].
        tokens: label ids of the transcript, in order.
        blank_id: label id of the CTC blank.

    Returns:
        A ``(num_frames, len(tokens))`` float tensor.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    # Guard: with a single token, `-num_tokens + 1` is 0 and this slice would
    # overwrite the whole first column.
    if num_tokens > 1:
        trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis


trellis = get_trellis(emission, tokens)
st.write('Trellis =', trellis)


# Part G: Labels and Time -Inf | +Inf
def n_inf_to_p_inf():
    """Return a figure of the trellis annotated with its -Inf / +Inf regions."""
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    # Shift the "+ Inf" annotation to the left by decreasing the x-coordinate value
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 1.4, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.25, location="bottom")
    fig.tight_layout()
    return fig


_show(n_inf_to_p_inf())


# Part H: Backtrack Trellis Emissions Tensor and Tokens
@dataclass
class Point:
    """One step of the alignment path."""

    token_index: int  # index into the transcript / tokens list
    time_index: int  # frame index
    score: float  # frame-wise probability (exp of the log-probability)


def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the most likely alignment path from the trellis.

    Walks from the last frame / last token back to token 0, choosing at each
    frame whether the previous step stayed on the same token or changed.

    Returns:
        A list of ``Point`` ordered by increasing ``time_index``.

    Raises:
        RuntimeError: if the frames run out before all tokens are consumed
            (e.g. the transcript is too long for the audio).
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        if t <= 0:
            raise RuntimeError(
                "Backtracking ran out of frames before reaching the first token; "
                "the transcript may be too long for the audio."
            )

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means it reached the first token.
    # Fill up the rest for the sake of visualization.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


path = backtrack(trellis, emission, tokens)
st.write('Token index, Time index and Score:')
for p in path:
    st.write(p)


# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Return a figure of the trellis with the backtracked path blanked out.

    NaN cells render as empty, which makes the path visible on the heat map.
    """
    trellis_with_path = trellis.clone()
    for p in path:
        trellis_with_path[p.time_index, p.token_index] = float("nan")

    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    fig.tight_layout()
    return fig


_show(plot_trellis_with_path(trellis, path))


# Part J: Merge Repeats | Segments
@dataclass
class Segment:
    """A labelled half-open frame range ``[start, end)`` with a mean score."""

    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path, transcript=None):
    """Collapse consecutive path points on the same token into segments.

    Args:
        path: output of ``backtrack``.
        transcript: the tokenized (cleaned) transcript the path indexes into.
            Defaults to the module-level ``clean_transcript``.

    Returns:
        One ``Segment`` per run of identical ``token_index``, with the mean
        frame score of that run.
    """
    if transcript is None:
        transcript = clean_transcript

    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


segments = merge_repeats(path, clean_transcript)
st.write('Segments:')
for seg in segments:
    st.write(seg)


# Part K: Trellis with Segments Visualization
def plot_trellis_with_segments(trellis, segments, transcript, path=None):
    """Return a two-panel figure of character segments and their scores.

    Top: the trellis with each non-separator segment blanked out and
    annotated with its label and score. Bottom: per-segment mean scores as
    gray bars, plus (when ``path`` is given) the per-frame path scores.

    Args:
        trellis: output of ``get_trellis``.
        segments: output of ``merge_repeats``; the i-th segment is drawn on
            trellis column i.
        transcript: the cleaned transcript ``path`` indexes into.
        path: optional output of ``backtrack`` for the per-frame bars.
    """
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

    # Offset label and score vertically so they do not overlap
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")

    ax2.set_title("Label probability with and without repetition")
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")

    if path is not None:
        xs, hs = [], []
        for p in path:
            label = transcript[p.token_index]
            if label != "|":
                xs.append(p.time_index + 1)
                hs.append(p.score)
        ax2.bar(xs, hs, width=0.9, alpha=0.9)

    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()
    return fig


_show(plot_trellis_with_segments(trellis, segments, clean_transcript, path))


# Part L: Merge words | Segments
def merge_words(segments, separator="|"):
    """Group character segments between separators into word segments.

    The word score is the length-weighted mean of its character scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(seg.label for seg in segs)
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(segments)
st.write('Word Segments:')
for word in word_segments:
    st.write(word)


# Part M: Alignment Visualizations
def plot_alignments(trellis, segments, word_segments, waveform=None, sample_rate=44100):
    """Return a figure of the segmented trellis above the audio spectrogram.

    Args:
        trellis: output of ``get_trellis``.
        segments: character segments (``merge_repeats``).
        word_segments: word segments (``merge_words``).
        waveform: audio as a tensor or array; for 2-D input the first
            row (channel) is plotted. Defaults to 1024 samples of random
            noise (a placeholder, generated per call).
        sample_rate: sample rate of ``waveform`` in Hz.
    """
    if waveform is None:
        waveform = np.random.randn(1024)
    if isinstance(waveform, torch.Tensor):
        signal = waveform.detach().cpu().numpy()
    else:
        signal = np.asarray(waveform)
    # specgram needs 1-D samples; torchaudio returns (channels, samples).
    if signal.ndim > 1:
        signal = signal[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform. `ratio` converts trellis frames to seconds.
    NFFT = 1024
    ratio = len(signal) / sample_rate / trellis.size(0)
    ax2.specgram(signal, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig


_show(plot_alignments(trellis, segments, word_segments, waveform, sample_rate))


# Part N: Display Segment
def _segment_info(i):
    """Return ``(caption, samples)`` for the i-th word segment.

    ``samples`` is the (channels, samples) slice of ``waveform`` covered by
    the word; trellis frames are mapped to samples proportionally.
    """
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    caption = f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec"
    return caption, waveform[:, x0:x1]


def display_segment(i):
    """Print the i-th word's timing and return it as an IPython Audio widget."""
    caption, segment = _segment_info(i)
    print(caption)
    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)


# Part O: Audio generation for each segment
st.write('Abby Cadabby Transcript:')
st.write(clean_transcript)
with open(SPEECH_FILE, 'rb') as audio_file:
    st.audio(audio_file.read(), format='audio/wav')

for segment_index in range(len(word_segments)):
    caption, segment = _segment_info(segment_index)
    st.write(caption)
    # First channel as a 1-D float array; st.audio needs sample_rate for NumPy input.
    st.audio(segment[0].numpy(), sample_rate=bundle.sample_rate)