from copy import deepcopy
from os.path import basename, splitext

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torchaudio
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances
from speechbrain.pretrained import EncoderClassifier


def similarity_matrix(embeds, metric="cosine"):
    return pairwise_distances(embeds, metric=metric)


def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwargs):
    """
    Cluster embeds using Agglomerative Hierarchical Clustering
    """
    if n_clusters is None:
        assert threshold, "If n_clusters is not defined, threshold must be defined"

    S = similarity_matrix(embeds, metric=metric)

    if n_clusters is None:
        cluster_model = AgglomerativeClustering(
            n_clusters=None,
            affinity="precomputed",
            linkage="average",
            compute_full_tree=True,
            distance_threshold=threshold,
        )
        return cluster_model.fit_predict(S)
    else:
        cluster_model = AgglomerativeClustering(
            n_clusters=n_clusters, affinity="precomputed", linkage="average"
        )
        return cluster_model.fit_predict(S)
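

# Compatibility note (an assumption about the installed scikit-learn, not part of the
# original code): recent scikit-learn releases renamed AgglomerativeClustering's
# `affinity` argument to `metric`; the calls above assume a version that still
# accepts `affinity="precomputed"`.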


##########################################
# Spectral clustering
# A lot of these methods are lifted from
# https://github.com/wq2012/SpectralCluster
##########################################


def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
    """
    Cluster embeds using Spectral Clustering
    """
    if n_clusters is None:
        assert threshold, "If n_clusters is not defined, threshold must be defined"

    S = compute_affinity_matrix(embeds)
    if enhance_sim:
        S = sim_enhancement(S)

    if n_clusters is None:
        (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
        # Get number of clusters.
        k = compute_number_of_clusters(eigenvalues, 100, threshold)

        # Get spectral embeddings.
        spectral_embeddings = eigenvectors[:, :k]

        # Run K-Means++ on spectral embeddings.
        # Note: the correct way would be a K-Means implementation that supports a
        # customized distance measure such as cosine distance. The scikit-learn
        # implementation does NOT, which is inconsistent with the paper.
        kmeans_clusterer = KMeans(
            n_clusters=k, init="k-means++", max_iter=300, random_state=0
        )
        labels = kmeans_clusterer.fit_predict(spectral_embeddings)
        return labels
    else:
        cluster_model = SpectralClustering(
            n_clusters=n_clusters, affinity="precomputed"
        )
        return cluster_model.fit_predict(S)


def diagonal_fill(A):
    """
    Sets the diagonal elements of the matrix to the max of each row
    """
    np.fill_diagonal(A, 0.0)
    A[np.diag_indices(A.shape[0])] = np.max(A, axis=1)
    return A


def gaussian_blur(A, sigma=1.0):
    """
    Does a gaussian blur on the affinity matrix
    """
    return gaussian_filter(A, sigma=sigma)


def row_threshold_mult(A, p=0.95, mult=0.01):
    """
    For each row, multiply elements smaller than the row's p'th percentile by mult
    """
    percentiles = np.percentile(A, p * 100, axis=1)
    mask = A < percentiles[:, np.newaxis]

    A = (mask * mult * A) + (~mask * A)
    return A


def symmetrization(A):
    """
    Symmetrization: Y_{i,j} = max(S_{ij}, S_{ji})
    """
    return np.maximum(A, A.T)


def diffusion(A):
    """
    Diffusion: Y <- YY^T
    """
    return np.dot(A, A.T)


def row_max_norm(A):
    """
    Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik})
    """
    maxes = np.amax(A, axis=1)
    # Divide each row by its own maximum (the newaxis makes the broadcast row-wise)
    return A / maxes[:, np.newaxis]


def sim_enhancement(A):
    func_order = [
        diagonal_fill,
        gaussian_blur,
        row_threshold_mult,
        symmetrization,
        diffusion,
        row_max_norm,
    ]
    for f in func_order:
        A = f(A)
    return A
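
# Note on the order above: this appears to follow the affinity-refinement recipe from
# the SpectralCluster repository linked earlier (fill the diagonal, blur, row-wise
# threshold, symmetrize, diffuse, then row-max normalize), applied left to right.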


def compute_affinity_matrix(X):
    """Compute the affinity matrix from data.

    Note that the range of affinity is [0, 1].

    Args:
        X: numpy array of shape (n_samples, n_features)

    Returns:
        affinity: numpy array of shape (n_samples, n_samples)
    """
    # Normalize the data.
    l2_norms = np.linalg.norm(X, axis=1)
    X_normalized = X / l2_norms[:, None]
    # Compute cosine similarities. Range is [-1, 1].
    cosine_similarities = np.matmul(X_normalized, np.transpose(X_normalized))
    # Compute the affinity. Range is [0, 1].
    # Note that this step is not mentioned in the paper!
    affinity = (cosine_similarities + 1.0) / 2.0
    return affinity
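
# Worked values for the (cosine + 1) / 2 mapping above: a cosine similarity of -1.0
# becomes affinity 0.0, 0.0 becomes 0.5, and 1.0 becomes 1.0, so identical embeddings
# sit at 1.0 on the diagonal of the affinity matrix.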


def compute_sorted_eigenvectors(A):
    """Sort eigenvectors by the real part of eigenvalues.

    Args:
        A: the matrix to perform eigen analysis with shape (M, M)

    Returns:
        w: sorted eigenvalues of shape (M,)
        v: sorted eigenvectors, where v[:, i] corresponds to the ith largest
           eigenvalue
    """
    # Eigen decomposition.
    eigenvalues, eigenvectors = np.linalg.eig(A)
    eigenvalues = eigenvalues.real
    eigenvectors = eigenvectors.real
    # Sort from largest to smallest.
    index_array = np.argsort(-eigenvalues)
    # Re-order.
    w = eigenvalues[index_array]
    v = eigenvectors[:, index_array]
    return w, v


def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2):
    """Compute the number of clusters using the eigengap principle.

    Args:
        eigenvalues: sorted eigenvalues of the affinity matrix
        max_clusters: maximum number of clusters allowed
        stop_eigenvalue: we do not look at eigenvalues smaller than this

    Returns:
        number of clusters as an integer
    """
    max_delta = 0
    max_delta_index = 0
    range_end = len(eigenvalues)
    if max_clusters and max_clusters + 1 < range_end:
        range_end = max_clusters + 1
    for i in range(1, range_end):
        if eigenvalues[i - 1] < stop_eigenvalue:
            break
        delta = eigenvalues[i - 1] / eigenvalues[i]
        if delta > max_delta:
            max_delta = delta
            max_delta_index = i
    return max_delta_index
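
# Worked example of the eigengap search above (illustrative values, not from real data):
# with stop_eigenvalue left at 1e-2, sorted eigenvalues [2.0, 1.9, 0.3, 0.05] give
# ratios 2.0/1.9 ≈ 1.05, 1.9/0.3 ≈ 6.33 and 0.3/0.05 = 6.0, so the largest gap sits
# after the second eigenvalue and the function returns 2 clusters.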


class Diarizer:
    def __init__(
        self, device='cuda:0', embed_model="xvec", cluster_method="sc", window=1.5, period=0.75
    ):
        self.device = device
        assert embed_model in [
            "xvec",
            "ecapa",
        ], "Only xvec and ecapa are supported options"
        assert cluster_method in [
            "ahc",
            "sc",
        ], "Only ahc and sc are supported clustering options"

        if cluster_method == "ahc":
            self.cluster = cluster_AHC
        if cluster_method == "sc":
            self.cluster = cluster_SC

        self.vad_model, self.get_speech_ts = self.setup_VAD()

        self.run_opts = {"device": self.device}

        if embed_model == "xvec":
            # x-vector branch so the 'xvec' option asserted above actually loads a model
            self.embed_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-xvect-voxceleb",
                savedir="pretrained_models/spkrec-xvect-voxceleb",
                run_opts=self.run_opts,
            )
        if embed_model == "ecapa":
            self.embed_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="pretrained_models/spkrec-ecapa-voxceleb",
                run_opts=self.run_opts,
            )

        self.window = window
        self.period = period

    def setup_VAD(self):
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            # force_reload=True,
        )
        get_speech_ts = utils[0]
        return model, get_speech_ts

    def vad(self, signal):
        """
        Runs the VAD model on the signal
        """
        return self.get_speech_ts(signal.to(self.device), self.vad_model.to(self.device))

    def windowed_embeds(self, signal, fs, window=1.5, period=0.75):
        """
        Calculates embeddings for windows across the signal

        window: length of the window, in seconds
        period: hop (stride) of the window, in seconds

        returns: embeddings, segment info
        """
        len_window = int(window * fs)
        len_period = int(period * fs)
        len_signal = signal.shape[1]

        # Get the windowed segments
        segments = []
        start = 0
        while start + len_window < len_signal:
            segments.append([start, start + len_window])
            start += len_period

        segments.append([start, len_signal - 1])

        embeds = []
        with torch.no_grad():
            for i, j in segments:
                signal_seg = signal[:, i:j]
                seg_embed = self.embed_model.encode_batch(signal_seg)
                embeds.append(seg_embed.squeeze(0).squeeze(0).cpu().numpy())

        embeds = np.array(embeds)
        return embeds, np.array(segments)
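
    # Worked example of the windowing above (illustrative): at fs=16000 with window=1.5 s
    # and period=0.75 s, len_window=24000 and len_period=12000 samples, so a 50000-sample
    # utterance yields segments [0, 24000], [12000, 36000], [24000, 48000] plus the final
    # tail segment [36000, 49999].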

    def recording_embeds(self, signal, fs, speech_ts):
        """
        Takes signal and VAD output (speech_ts) and produces windowed embeddings

        returns: embeddings, segment info
        """
        all_embeds = []
        all_segments = []

        for utt in speech_ts:
            start = utt["start"]
            end = utt["end"]
            utt_signal = signal[:, start:end]
            utt_embeds, utt_segments = self.windowed_embeds(
                utt_signal, fs, self.window, self.period
            )
            all_embeds.append(utt_embeds)
            all_segments.append(utt_segments + start)

        all_embeds = np.concatenate(all_embeds, axis=0)
        all_segments = np.concatenate(all_segments, axis=0)
        return all_embeds, all_segments

    @staticmethod
    def join_segments(cluster_labels, segments, tolerance=5):
        """
        Joins up same speaker segments, resolves overlap conflicts

        Uses the midpoint for overlap conflicts
        tolerance allows for very minimally separated segments to be combined
        (in samples)
        """
        assert len(cluster_labels) == len(segments)

        new_segments = [
            {"start": segments[0][0], "end": segments[0][1], "label": cluster_labels[0]}
        ]

        for l, seg in zip(cluster_labels[1:], segments[1:]):
            start = seg[0]
            end = seg[1]

            protoseg = {"start": seg[0], "end": seg[1], "label": l}

            if start <= new_segments[-1]["end"] + tolerance:
                # If segments overlap (or are within tolerance samples of each other)
                if l == new_segments[-1]["label"]:
                    # If overlapping segment has same label
                    new_segments[-1]["end"] = end
                else:
                    # If overlapping segment has diff label
                    # Resolve by setting new start to midpoint
                    # And setting last segment end to midpoint
                    overlap = new_segments[-1]["end"] - start
                    midpoint = start + overlap // 2
                    new_segments[-1]["end"] = midpoint
                    protoseg["start"] = midpoint
                    new_segments.append(protoseg)
            else:
                # If there's no overlap just append
                new_segments.append(protoseg)

        return new_segments
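
    # Worked example of the overlap resolution above (illustrative): if the previous
    # segment ends at sample 36000 and the next segment, with a different label, starts
    # at 30000, the overlap is 6000 samples and the midpoint is 33000, so the previous
    # segment's end and the new segment's start both become 33000.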

    @staticmethod
    def make_output_seconds(cleaned_segments, fs):
        """
        Convert cleaned segments to readable format in seconds
        """
        for seg in cleaned_segments:
            seg["start_sample"] = seg["start"]
            seg["end_sample"] = seg["end"]
            seg["start"] = seg["start"] / fs
            seg["end"] = seg["end"] / fs
        return cleaned_segments

    def diarize(
        self,
        wav_file,
        num_speakers=2,
        threshold=None,
        silence_tolerance=0.2,
        enhance_sim=True,
        extra_info=False,
        outfile=None,
    ):
        """
        Diarize a 16 kHz mono wav file, producing a list of segments

        Inputs:
            wav_file (path): Path to input audio file
            num_speakers (int) or NoneType: Number of speakers to cluster to
            threshold (float) or NoneType: Clustering threshold to use if
                    num_speakers is not defined
            silence_tolerance (float): Same-speaker segments separated by less than
                    silence_tolerance seconds will be joined into a single segment
            enhance_sim (bool): Whether or not to perform affinity matrix enhancement
                    during spectral clustering.
                    If self.cluster_method is 'ahc' this option does nothing.
            extra_info (bool): Whether or not to return the embeddings and raw segments
                    in addition to segments
            outfile (path): If specified, an RTTM file will be written to this path

        Outputs:
            If extra_info is False:
                segments (list): List of dicts with segment information
                    {
                        'start': Start time of segment in seconds,
                        'start_sample': Starting index of segment,
                        'end': End time of segment in seconds,
                        'end_sample': Ending index of segment,
                        'label': Cluster label of segment
                    }
            If extra_info is True:
                dict with keys:
                    'clean_segments': the segment list described above,
                    'embeds' (np.array): embeddings, one row per window,
                    'segments': start/end sample indexes for each embedding in 'embeds',
                    'cluster_labels': cluster label for each embedding in 'embeds'

        Uses AHC/SC to cluster
        """
        signal, fs = torchaudio.load(wav_file)
        if signal.shape[0] > 1:
            # Keep only the first channel if the file is not mono
            signal = signal[:1, :]
        if fs != 16000:
            signal = torchaudio.functional.resample(signal, fs, 16000)
            fs = 16000
        speech_ts = self.vad(signal[0])
        if len(speech_ts) >= 1:
            embeds, segments = self.recording_embeds(signal, fs, speech_ts)
            if len(embeds) > 1:
                cluster_labels = self.cluster(
                    embeds,
                    n_clusters=num_speakers,
                    threshold=threshold,
                    enhance_sim=enhance_sim,
                )
            else:
                # A single embedding cannot be clustered; assign it to speaker 0
                cluster_labels = np.zeros(len(embeds), dtype=int)
            cleaned_segments = self.join_segments(cluster_labels, segments)
            cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
            cleaned_segments = self.join_samespeaker_segments(
                cleaned_segments, silence_tolerance=silence_tolerance
            )
            if outfile:
                self.rttm_output(
                    cleaned_segments, splitext(basename(wav_file))[0], outfile=outfile
                )

            if not extra_info:
                return cleaned_segments
            else:
                return {
                    "clean_segments": cleaned_segments,
                    "embeds": embeds,
                    "segments": segments,
                    "cluster_labels": cluster_labels,
                }
        else:
            print("Couldn't find any speech during VAD")
            return {}

    @staticmethod
    def rttm_output(segments, recname, outfile=None):
        assert outfile, "Please specify an outfile"
        rttm_line = "SPEAKER {} 0 {} {} <NA> <NA> {} <NA> <NA>\n"
        with open(outfile, "w") as fp:
            for seg in segments:
                start = seg["start"]
                offset = seg["end"] - seg["start"]
                label = seg["label"]
                line = rttm_line.format(recname, start, offset, label)
                fp.write(line)
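
    # Example line produced by the template above for a segment starting at 12.3 s,
    # lasting 4.56 s, with cluster label 1, in a recording named "rec":
    # SPEAKER rec 0 12.3 4.56 <NA> <NA> 1 <NA> <NA>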

    @staticmethod
    def join_samespeaker_segments(segments, silence_tolerance=0.5):
        """
        Join up segments that belong to the same speaker,
        even if there is a duration of silence in between them.

        If the silence is greater than silence_tolerance, does not join up
        """
        new_segments = [segments[0]]

        for seg in segments[1:]:
            if seg["label"] == new_segments[-1]["label"]:
                if new_segments[-1]["end"] + silence_tolerance >= seg["start"]:
                    new_segments[-1]["end"] = seg["end"]
                    new_segments[-1]["end_sample"] = seg["end_sample"]
                else:
                    new_segments.append(seg)
            else:
                new_segments.append(seg)
        return new_segments

    @staticmethod
    def match_diarization_to_transcript(segments, text_segments):
        """
        Match the output of .diarize to word segments
        """
        text_starts, text_ends, text_segs = [], [], []
        for s in text_segments:
            text_starts.append(s["start"])
            text_ends.append(s["end"])
            text_segs.append(s["text"])

        text_starts = np.array(text_starts)
        text_ends = np.array(text_ends)
        text_segs = np.array(text_segs)

        # Get the earliest start from either diar output or asr output
        earliest_start = np.min([text_starts[0], segments[0]["start"]])

        worded_segments = segments.copy()
        worded_segments[0]["start"] = earliest_start
        cutoffs = []

        for seg in worded_segments:
            end_idx = np.searchsorted(text_ends, seg["end"], side="left") - 1
            cutoffs.append(end_idx)

        indexes = [[0, cutoffs[0]]]
        for c in cutoffs[1:]:
            indexes.append([indexes[-1][-1], c])

        indexes[-1][-1] = len(text_segs)

        final_segments = []

        for i, seg in enumerate(worded_segments):
            s_idx, e_idx = indexes[i]
            words = text_segs[s_idx:e_idx]
            newseg = deepcopy(seg)
            newseg["words"] = " ".join(words)
            final_segments.append(newseg)

        return final_segments

    def match_diarization_to_transcript_ctm(self, segments, ctm_file):
        """
        Match the output of .diarize to a ctm file produced by asr
        """
        ctm_df = pd.read_csv(
            ctm_file,
            delimiter=" ",
            names=["utt", "channel", "start", "offset", "word", "confidence"],
        )
        ctm_df["end"] = ctm_df["start"] + ctm_df["offset"]

        ends = ctm_df["end"].values
        words = ctm_df["word"].values

        # Get the earliest start from either diar output or asr output
        earliest_start = np.min([ctm_df["start"].values[0], segments[0]["start"]])

        worded_segments = self.join_samespeaker_segments(segments)
        worded_segments[0]["start"] = earliest_start
        cutoffs = []

        for seg in worded_segments:
            end_idx = np.searchsorted(ends, seg["end"], side="left") - 1
            cutoffs.append(end_idx)

        indexes = [[0, cutoffs[0]]]
        for c in cutoffs[1:]:
            indexes.append([indexes[-1][-1], c])

        indexes[-1][-1] = len(words)

        final_segments = []

        for i, seg in enumerate(worded_segments):
            s_idx, e_idx = indexes[i]
            seg_words = words[s_idx:e_idx]
            seg["words"] = " ".join(seg_words)
            if len(seg_words) >= 1:
                final_segments.append(seg)
            else:
                print(
                    "Removed segment between {} and {} as no words were matched".format(
                        seg["start"], seg["end"]
                    )
                )

        return final_segments

    @staticmethod
    def nice_text_output(worded_segments, outfile):
        with open(outfile, "w") as fp:
            for seg in worded_segments:
                fp.write(
                    "[{} to {}] Speaker {}: \n".format(
                        round(seg["start"], 2), round(seg["end"], 2), seg["label"]
                    )
                )
                fp.write("{}\n\n".format(seg["words"]))


class DiarizationPipeline:
    def __init__(self, device=None):
        super(DiarizationPipeline, self).__init__()
        self.diar = Diarizer(
            device=device,
            embed_model='ecapa',   # supported types: ['xvec', 'ecapa']
            cluster_method='ahc',  # supported types: ['ahc', 'sc']
            window=1,              # size of window to extract embeddings (in seconds)
            period=0.1,            # hop of window (in seconds)
        )

    def save_speaker_audios(self, segments: list, audio_path: str):
        """
        :param segments: result diarization timestamps
        :param audio_path: path to the original audio file
        :return: out_wav_paths: list of per-speaker audio paths
        """
        signal, sr = librosa.load(audio_path, sr=None, mono=True)
        out_wav_paths = []
        segments = pd.DataFrame(segments)
        segments = self.filter_small_speech(segments)
        # Order speakers by total speaking time, longest first
        sort_labels = segments.groupby(['label'])['duration'].sum().nlargest(len(set(segments.label))).index
        for indx, label in enumerate(sort_labels):
            temp_df = segments[segments.label == label]
            output_signal = []
            for _, r in temp_df.iterrows():
                start = int(r["start"] * sr)
                end = int(r["end"] * sr)
                output_signal.append(signal[start:end])
            # Build the output path from the input stem so non-wav inputs also work
            out_wav_path = f"{splitext(audio_path)[0]}_{indx}.wav"
            sf.write(out_wav_path, np.concatenate(output_signal), sr)
            out_wav_paths.append(out_wav_path)
        return out_wav_paths

    def filter_small_speech(self, segments):
        """
        Drop speakers who account for less than 1.5% of the total speech duration
        """
        segments['duration'] = segments.end - segments.start
        durs = segments.groupby('label')[['duration']].sum()
        labels = durs[durs['duration'] / durs['duration'].sum() > 0.015].index
        return segments[segments.label.isin(labels)]

    def __call__(self, input_wav_path: str) -> dict:
        segments = self.diar.diarize(
            input_wav_path,
            num_speakers=None,
            threshold=9e-1,
        )
        if segments:
            output_wav_paths = self.save_speaker_audios(segments, input_wav_path)
            return {
                'count_speakers': max(seg['label'] for seg in segments) + 1,
                'diarization_segments': segments,
                'output_diar_audio_paths': output_wav_paths,
            }
        else:
            return {}
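

# A minimal sketch of using Diarizer directly (without DiarizationPipeline) and writing an
# RTTM file. 'meeting.wav' and 'meeting.rttm' are placeholder paths used only for
# illustration; the input is assumed to be a mono 16 kHz WAV.
def _example_direct_diarizer_usage():
    diar = Diarizer(device='cpu', embed_model='ecapa', cluster_method='sc')
    segments = diar.diarize('meeting.wav', num_speakers=2, outfile='meeting.rttm')
    for seg in segments:
        print("{:.2f}-{:.2f}s speaker {}".format(seg['start'], seg['end'], seg['label']))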


if __name__ == '__main__':
    diarization = DiarizationPipeline(device='cuda:0')
    diarization('../dialog.mp3')