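"""Frame-level clustering of infant-cry audio.

Wraps several clustering algorithms (k-means, DTW k-means, k-shape,
affinity propagation, agglomerative, mean shift, bisecting k-means,
DBSCAN, OPTICS, Birch), applies them to latent features produced by a
trained model, maps the frame labels back to the sample level, and
visualizes and evaluates the results against ground-truth annotations.
"""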
import numpy as np
import librosa
import torch
import matplotlib.pyplot as plt

from metrics.event_based_metrics import event_metrics
from src.audio_preprocessing import readLabels, object_padding, fbank_features_extraction
from tslearn.clustering import TimeSeriesKMeans, KShape
from tslearn.metrics import dtw
from sklearn.cluster import KMeans, AffinityPropagation, AgglomerativeClustering, MeanShift, estimate_bandwidth, DBSCAN, OPTICS, Birch
from sklearn.metrics import accuracy_score, f1_score

import os
# Cap the number of OpenMP threads (e.g. to avoid the KMeans memory-leak
# warning on Windows with MKL); must be set before the threads spin up.
os.environ["OMP_NUM_THREADS"] = '3'

def standardize_array(array):
    mean = np.mean(array, axis=0)
    std = np.std(array, axis=0)

    # Avoid division by zero
    std[std == 0] = 1
    standardized_array = (array - mean) / std
    return standardized_array

# Functions to cluster and label a single audio's frames
def kmeans_clustering(audio_data, n_clusters=2):
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    labels = kmeans.fit_predict(audio_data)
    return labels

def dtw_kmedoids_clustering(audio_data, n_clusters=2):
    # Despite the name, TimeSeriesKMeans with metric="dtw" performs DBA
    # k-means (DTW barycenter averaging), not k-medoids.
    km = TimeSeriesKMeans(n_clusters=n_clusters, metric="dtw", max_iter=10, random_state=42)
    labels = km.fit_predict(audio_data)
    return labels

def kshape_clustering(audio_data, n_clusters=2):
    ks = KShape(n_clusters=n_clusters, max_iter=10, random_state=42)
    labels = ks.fit_predict(audio_data)
    return labels

def affinity_propagation_clustering(audio_data):
    af = AffinityPropagation(random_state=42)  
    labels = af.fit_predict(audio_data)
    return labels

def agglomerative_clustering(audio_data, n_clusters=2):
    # Precompute the pairwise DTW distance matrix (O(n^2) DTW computations).
    distances = np.array([[dtw(x, y) for y in audio_data] for x in audio_data])
    agg = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
    labels = agg.fit_predict(distances)
    return labels

def mean_shift_clustering(audio_data):
    bandwidth = estimate_bandwidth(audio_data, quantile=0.2) 
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, cluster_all=False)
    labels = ms.fit_predict(audio_data)
    return labels

def bisecting_kmeans_clustering(audio_data, n_clusters=2):
    # Track sample indices rather than raw rows: testing `x in cluster` on
    # NumPy rows does an elementwise comparison and is not a valid
    # membership check, so labels are assigned through the index arrays.
    clusters = [np.arange(len(audio_data))]
    while len(clusters) < n_clusters:
        largest_cluster_idx = max(range(len(clusters)), key=lambda i: len(clusters[i]))
        largest_indices = clusters.pop(largest_cluster_idx)

        km = TimeSeriesKMeans(n_clusters=2, metric="dtw", max_iter=10, random_state=42)
        sub_labels = km.fit_predict(audio_data[largest_indices])
        clusters.append(largest_indices[sub_labels == 0])
        clusters.append(largest_indices[sub_labels == 1])

    labels = np.full(len(audio_data), -1)
    for i, cluster_indices in enumerate(clusters):
        labels[cluster_indices] = i

    return labels

def dbscan_clustering(audio_data, eps=0.5, min_samples=5):
    # A callable DTW metric is valid here, but pairwise DTW is computed on
    # the fly and can be slow for long feature sequences.
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=dtw)
    labels = dbscan.fit_predict(audio_data)
    return labels

def optics_clustering(audio_data, min_samples=5):
    optics = OPTICS(min_samples=min_samples, metric=dtw, cluster_method='xi')
    labels = optics.fit_predict(audio_data)
    return labels

def birch_clustering(audio_data, n_clusters=None, branching_factor=50, threshold=0.5):
    birch = Birch(n_clusters=n_clusters, branching_factor=branching_factor, threshold=threshold)
    labels = birch.fit_predict(audio_data)
    return labels

def clustering_predicting(model, annotation_file, audio_file, max_length, clustering_method="kmeans", k=2):
    signal, fs = librosa.load(audio_file)
    signal = object_padding(signal, max_length)
    truth_labels = readLabels(path=annotation_file, sample_rate=fs)
    truth_labels = object_padding(truth_labels, max_length)
     
    test_audio = fbank_features_extraction([audio_file], max_length)

    test_input = torch.tensor(test_audio, dtype=torch.float32)
    x_recon, test_latent, test_u, loss = model(test_input)

    # NOTE: 703 is the hard-coded number of feature frames produced for an
    # input of max_length samples; keep it in sync with fbank_features_extraction.
    clustering_input = standardize_array(test_u.reshape((703, -1)).detach().numpy())
    clustering_label = None

    if clustering_method == "kmeans":
        clustering_label = kmeans_clustering(clustering_input, n_clusters=k)
    elif clustering_method == "dtw":
        clustering_label = dtw_kmedoids_clustering(clustering_input, n_clusters=k)
    elif clustering_method == "kshape":
        clustering_label = kshape_clustering(clustering_input, n_clusters=k)
    elif clustering_method == "affinity":
        clustering_label = affinity_propagation_clustering(clustering_input)
    elif clustering_method == "agglomerative":
        clustering_label = agglomerative_clustering(clustering_input, n_clusters=k)
    elif clustering_method == "mean_shift":
        clustering_label = mean_shift_clustering(clustering_input)
    elif clustering_method == "bisecting":
        clustering_label = bisecting_kmeans_clustering(clustering_input, n_clusters=k)
    elif clustering_method == "DBSCAN":
        clustering_label = dbscan_clustering(clustering_input)
    elif clustering_method == "OPTICS":
        clustering_label = optics_clustering(clustering_input)
    elif clustering_method == "Birch":
        clustering_label = birch_clustering(clustering_input, n_clusters=k)

    # Map frame-level cluster labels back to the sample level: 25 ms windows
    # with a 10 ms hop, matching the fbank frame parameters.
    label_timeseries = np.zeros(max_length)
    begin = 0
    end = int(0.025 * fs)
    shift_step = int(0.01 * fs)
    for i in range(clustering_label.shape[0]):
        # abs() folds DBSCAN/OPTICS noise labels (-1) into cluster 1.
        label_timeseries[begin:end] = abs(clustering_label[i])
        begin = begin + shift_step
        end = end + shift_step

    return signal, fs, np.array(truth_labels), label_timeseries

def signal_visualization(signal, fs, truth_labels, label_timeseries):
    # define time axis
    Ns = len(signal) # number of samples
    Ts = 1 / fs # sampling period
    t = np.arange(Ns) * Ts # time axis in seconds
    norm_coef = 1.1 * np.max(signal)
    edge_ind = np.min([signal.shape[0], len(truth_labels)])
     
    plt.figure(figsize=(24, 6))
    plt.plot(t[:edge_ind], signal[:edge_ind])
    plt.plot(t[:edge_ind], truth_labels[:edge_ind] * norm_coef)
    plt.plot(t[:edge_ind], label_timeseries[:edge_ind] * norm_coef)

    plt.title("Ground truth labels")
    plt.legend(['Signal', 'Cry', 'Clusters'])
    plt.show()

def cluster_visualization(signal, fs, truth_labels, label_timeseries):
    # define time axis
    Ns = len(signal)  # number of samples
    Ts = 1 / fs  # sampling period
    t = np.arange(Ns) * Ts  # time axis in seconds
    norm_coef = 1.1 * np.max(signal)
    edge_ind = np.min([signal.shape[0], len(truth_labels)])

    plt.figure(figsize=(24, 6))
    line_signal, = plt.plot(t[:edge_ind], signal[:edge_ind])

    # Identify 'cry' and 'non-cry' segments
    cry_indices = np.where(truth_labels == 1)[0]
    non_cry_indices = np.where(truth_labels == 0)[0]

    # Fill rectangular segments for 'cry' and 'non-cry'
    # Identify start and end points for each continuous segment
    start_cry = np.insert(np.where(np.diff(cry_indices) != 1)[0] + 1, 0, 0)
    end_cry = np.append(np.where(np.diff(cry_indices) != 1)[0], len(cry_indices) - 1)

    # Fill rectangular segments
    for start, end in zip(start_cry, end_cry):
        plt.fill_between(
            t[cry_indices[start:end+1]], 
            0, 
            norm_coef, 
            color='orange', 
            alpha=0.5,  # Adjust transparency as needed
            label='Cry' if start == start_cry[0] else None  # Avoid duplicate labels
        )
    legend_handles = []
    legend_handles.append(plt.Rectangle((0, 0), 1, 1, color='orange', alpha=0.5))

    start_non_cry = np.insert(np.where(np.diff(non_cry_indices) != 1)[0] + 1, 0, 0)
    end_non_cry = np.append(np.where(np.diff(non_cry_indices) != 1)[0], len(non_cry_indices) - 1)   

    # Fill rectangular segments
    for start, end in zip(start_non_cry, end_non_cry):
        plt.fill_between(
            t[non_cry_indices[start:end+1]], 
            0, 
            norm_coef, 
            color='gray', 
            alpha=0.5,  # Adjust transparency as needed
            label='Non-cry' if start == start_non_cry[0] else None  # Avoid duplicate labels
        )
    legend_handles.append(plt.Rectangle((0, 0), 1, 1, color='gray', alpha=0.5))

    # Get unique values in label_timeseries to assign distinct colors
    unique_labels = np.unique(label_timeseries)
    cmap = plt.get_cmap('tab10')  # You can choose other colormaps as needed

    # Fill rectangular segments for each unique label
    for i, label in enumerate(unique_labels):
        label_indices = np.where(label_timeseries == label)[0] 
        
        # Identify start and end points for each continuous segment
        start_indices = np.insert(np.where(np.diff(label_indices) != 1)[0] + 1, 0, 0)
        end_indices = np.append(np.where(np.diff(label_indices) != 1)[0], len(label_indices) - 1)

        # Fill rectangular segments
        for start, end in zip(start_indices, end_indices):
            plt.fill_between(
                t[label_indices[start:end+1]], 
                0, 
                -norm_coef, 
                color=cmap(i), 
                alpha=0.5,  # Adjust transparency as needed
                label=f'Cluster {label}' if start == start_indices[0] else None  # Avoid duplicate labels
            )
        legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=cmap(i), alpha=0.5))

    plt.title("Audio Clustering")
    plt.legend(
        [line_signal] + legend_handles, 
        ['Signal'] + ['Cry', 'Non-Cry'] + [f'Cluster {label}' for label in unique_labels]
    )
    plt.show()

def clustering_evaluation(model, max_length, audio_files, annotation_files, domain_index=None, clustering_method="kmeans", k=2):
    acc_list, framef_list, eventf_list, iou_list = [], [], [], []
    switch_list = []
    if domain_index is None:
        domain_index = range(len(audio_files))
    for i in domain_index:
        annotation_file = annotation_files[i]
        audio_file = audio_files[i]
        clustering_switch = False
        _, _, truth_labels, label_timeseries = clustering_predicting(model, annotation_file, audio_file, max_length, clustering_method, k)
        temp_accuracy = accuracy_score(truth_labels, label_timeseries)
        # Cluster ids are arbitrary, so score both label assignments and
        # keep the better one (f1_score expects y_true before y_pred).
        pred = label_timeseries > 0
        framef = max(f1_score(truth_labels, pred), f1_score(truth_labels, 1 - pred))
        if temp_accuracy < 0.5:
            # Below-chance accuracy means the cluster ids are flipped
            # relative to the ground truth; record the switch.
            clustering_accuracy = 1 - temp_accuracy
            clustering_switch = True
        else:
            clustering_accuracy = temp_accuracy
        acc_list.append(clustering_accuracy)
        switch_list.append(clustering_switch)
        framef_list.append(framef)
        eventf, iou, _, _, _ = event_metrics(truth_labels, label_timeseries, tolerance=2000, overlap_threshold=0.75)
        eventf_list.append(eventf)
        iou_list.append(iou)

    return acc_list, framef_list, eventf_list, iou_list, switch_list
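
# Minimal usage sketch. Everything below is a hypothetical example, not part
# of the module: the checkpoint path, file paths, and max_length are
# placeholder assumptions, and the model is assumed to return
# (x_recon, latent, u, loss) as used in clustering_predicting above.
if __name__ == "__main__":
    # model = torch.load("checkpoints/model.pt")       # hypothetical path
    # audio_files = ["data/audio/example.wav"]         # hypothetical path
    # annotation_files = ["data/labels/example.txt"]   # hypothetical path
    # max_length = 160000                              # assumed padding length
    #
    # signal, fs, truth, preds = clustering_predicting(
    #     model, annotation_files[0], audio_files[0], max_length,
    #     clustering_method="kmeans", k=2)
    # cluster_visualization(signal, fs, truth, preds)
    #
    # acc, framef, eventf, iou, switched = clustering_evaluation(
    #     model, max_length, audio_files, annotation_files,
    #     clustering_method="kmeans", k=2)
    # print(np.mean(acc), np.mean(framef), np.mean(eventf), np.mean(iou))
    pass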