TroglodyteDerivations commited on
Commit
c5c0498
·
verified ·
1 Parent(s): 7768e3f

Updated lines 218-259

Browse files
Files changed (1) hide show
  1. app.py +33 -33
app.py CHANGED
@@ -173,7 +173,7 @@ def plot_trellis_with_path(trellis, path):
173
  plt.tight_layout()
174
  return plt
175
 
176
- st.pyplot(plot_trellis_with_path(trellis, path))
177
 
178
  # Part J: Merge Repeats | Segments
179
  # Merge the labels
@@ -215,50 +215,50 @@ for seg in segments:
215
  st.write(seg)
216
 
217
  # Part K: Trellis with Segments Visualization
218
- def plot_trellis_with_segments(trellis, segments, transcript):
219
- # To plot trellis with path, we take advantage of 'nan' value
220
  trellis_with_path = trellis.clone()
221
  for i, seg in enumerate(segments):
222
  if seg.label != "|":
223
  trellis_with_path[seg.start : seg.end, i] = float("nan")
224
-
225
- fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
226
- ax1.set_title("Path, label and probability for each label")
227
  ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
228
-
229
- # Adjust the position of the annotations to spread them out
 
 
 
 
 
230
  for i, seg in enumerate(segments):
231
  if seg.label != "|":
232
- ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
233
- ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")
234
-
235
- ax2.set_title("Label probability with and without repetition")
236
- xs, hs, ws = [], [], []
 
 
 
 
 
 
 
 
 
 
 
 
237
  for seg in segments:
238
  if seg.label != "|":
239
- xs.append((seg.end + seg.start) / 2 + 0.4)
240
- hs.append(seg.score)
241
- ws.append(seg.end - seg.start)
242
- ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
243
- ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")
244
-
245
- xs, hs = [], []
246
- for p in path:
247
- label = transcript[p.token_index]
248
- if label != "|":
249
- xs.append(p.time_index + 1)
250
- hs.append(p.score)
251
-
252
- ax2.bar(xs, hs, width=0.9, alpha=0.9)
253
- ax2.axhline(0, color="black")
254
- ax2.grid(True, axis="y")
255
- ax2.set_ylim(-0.1, 1.1)
256
  fig.tight_layout()
257
  return fig
258
 
259
-
260
- plot_trellis_with_segments(trellis, segments, updated_clean_UPPERCASE_transcript)
261
- st.pyplot(plot_trellis_with_segments(trellis, segments, updated_clean_UPPERCASE_transcript))
262
 
263
  # Part L: Merge words | Segments
264
  # Merge words
 
173
  plt.tight_layout()
174
  return plt
175
 
176
+ st.pyplot(plt)
177
 
178
  # Part J: Merge Repeats | Segments
179
  # Merge the labels
 
215
  st.write(seg)
216
 
217
  # Part K: Trellis with Segments Visualization
218
+ def plot_alignments(trellis, segments, word_segments, waveform=np.random.randn(1024), sample_rate=44100):
 
219
  trellis_with_path = trellis.clone()
220
  for i, seg in enumerate(segments):
221
  if seg.label != "|":
222
  trellis_with_path[seg.start : seg.end, i] = float("nan")
223
+
224
+ fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))
225
+
226
  ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
227
+ ax1.set_facecolor("lightgray")
228
+ ax1.set_xticks([])
229
+ ax1.set_yticks([])
230
+
231
+ for word in word_segments:
232
+ ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")
233
+
234
  for i, seg in enumerate(segments):
235
  if seg.label != "|":
236
+ ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
237
+ ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
238
+
239
+ # The original waveform
240
+ NFFT = 1024 # Adjust NFFT to be less than the length of the waveform
241
+ ratio = len(waveform) / sample_rate / trellis.size(0)
242
+
243
+ # Add a small offset to the waveform to avoid log of zero or negative numbers
244
+ waveform = waveform + 1e-10
245
+
246
+ ax2.specgram(waveform, Fs=sample_rate, NFFT=NFFT)
247
+ for word in word_segments:
248
+ x0 = ratio * word.start
249
+ x1 = ratio * word.end
250
+ ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
251
+ ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)
252
+
253
  for seg in segments:
254
  if seg.label != "|":
255
+ ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
256
+ ax2.set_xlabel("time [second]")
257
+ ax2.set_yticks([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  fig.tight_layout()
259
  return fig
260
 
261
+ st.pyplot(fig)
 
 
262
 
263
  # Part L: Merge words | Segments
264
  # Merge words