File size: 4,618 Bytes
dd1cb8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import sklearn
import numpy as np
import matplotlib.pyplot as plt

# Args:
# y_true (ndarray): Ground truth labels (0 or 1).
# y_pred (ndarray): Predicted labels (0 or 1).

# Event-based metrics
def _extract_events(labels):
    """Return inclusive (start, end) index pairs for every run of 1s in *labels*.

    A run opens at the first element equal to 1 and closes at the element
    before the next element equal to 0. A run that is still open when the
    sequence ends is closed at the last index of *labels*.
    """
    events = []
    start = None
    for i, label in enumerate(labels):
        if label == 1 and start is None:
            start = i
        elif label == 0 and start is not None:
            events.append((start, i - 1))
            start = None

    if start is not None:
        events.append((start, len(labels) - 1))
    return events


def event_metrics(y_true, y_pred, tolerance, overlap_threshold=0.7, switch=False):
    """Compute event-based F-score and mean IoU between two label sequences.

    Consecutive positive samples are grouped into events. A predicted event
    matches a true event when it lies entirely inside the true event widened
    by ``tolerance`` samples on each side, and the overlap, measured relative
    to the shorter of the two events, is at least ``overlap_threshold``.

    Args:
        y_true: Ground truth labels (0 or 1), one per sample.
        y_pred: Predicted labels or scores; values > 0 are treated as
            positive. Converted with ``np.asarray``, so any array-like works.
        tolerance: Number of samples by which a predicted event may extend
            beyond either end of a true event and still be matched. Unmatched
            predicted events whose ``end - start`` does not exceed this value
            are ignored rather than counted as false positives.
        overlap_threshold: Minimum ratio of overlap length to the length of
            the shorter event for a match.
        switch: If true, invert the prediction first (``1 - y_pred``).

    Returns:
        A tuple ``(F, IOU, counted_events, fake_events, undetected_events)``:
        ``F`` is the event-based F-score (1 when there are no true positives,
        false positives or false negatives at all), ``IOU`` is the mean
        intersection-over-union of all matching (true, predicted) pairs (0 if
        there are none), and the three lists hold inclusive ``(start, end)``
        pairs of the detected true events, the false-positive predicted
        events and the missed true events respectively.
    """
    y_pred = np.asarray(y_pred)
    if switch:
        y_pred = (1 - y_pred) > 0
    else:
        y_pred = y_pred > 0

    true_events = _extract_events(y_true)
    # The trailing predicted event must be closed at the end of y_pred
    # (not at the length of the event list, as was done previously).
    pred_events = _extract_events(y_pred)

    counted_events = []
    fake_events = []
    undetected_events = []
    # Predicted events already matched to a true event; each predicted event
    # can be consumed only once.
    matched_preds = set()
    iou_list = []

    for true_event in true_events:
        lower_bound = true_event[0] - tolerance
        upper_bound = true_event[1] + tolerance
        true_length = true_event[1] - true_event[0] + 1
        detected = False

        for pred_event in pred_events:
            # The predicted event has to sit inside the tolerance window;
            # otherwise it cannot match and no overlap is computed for it.
            if not (lower_bound <= pred_event[0] and pred_event[1] <= upper_bound):
                continue

            overlap_start = max(true_event[0], pred_event[0])
            overlap_end = min(true_event[1], pred_event[1])
            overlap_length = overlap_end - overlap_start + 1
            pred_length = pred_event[1] - pred_event[0] + 1
            overlap_rate = overlap_length / min(true_length, pred_length)
            if overlap_rate < overlap_threshold:
                continue

            union_start = min(true_event[0], pred_event[0])
            union_end = max(true_event[1], pred_event[1])
            union_length = union_end - union_start + 1
            iou_list.append(overlap_length / union_length)

            # True positive: the true event is counted once, but every
            # not-yet-used matching prediction is consumed so it cannot
            # later be reported as a false positive.
            if pred_event not in matched_preds:
                matched_preds.add(pred_event)
                if not detected:
                    detected = True
                    counted_events.append((true_event[0], true_event[1]))

        # False negative: true event without any fresh matching prediction.
        if not detected:
            undetected_events.append((true_event[0], true_event[1]))

    # False positive: unmatched predictions longer than the tolerance.
    for pred_event in pred_events:
        if pred_event in matched_preds:
            continue
        if pred_event[1] - pred_event[0] > tolerance:
            fake_events.append((pred_event[0], pred_event[1]))

    tp = len(counted_events)
    fp = len(fake_events)
    fn = len(undetected_events)

    if tp == 0 and fn == 0 and fp == 0:
        # Nothing to detect and nothing falsely detected: perfect score.
        F = 1
    else:
        P = tp / (tp + fp) if (tp + fp) != 0 else 0
        R = tp / (tp + fn) if (tp + fn) != 0 else 0
        F = 2 * P * R / (P + R) if (P + R) != 0 else 0

    IOU = np.mean(iou_list) if iou_list else 0
    return F, IOU, counted_events, fake_events, undetected_events

def event_visualization(y_true, y_pred, counted_events, fake_events, undetected_events):
    """Plot both label sequences and shade the categorized events.

    Args:
        y_true: Ground truth label sequence, plotted against its indices.
        y_pred: Predicted label sequence, plotted against its indices.
        counted_events: ``(start, end)`` pairs shaded green ("Overlap event").
        fake_events: ``(start, end)`` pairs shaded red ("Fake event").
        undetected_events: ``(start, end)`` pairs shaded blue
            ("Undetected event").

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(range(len(y_true)), y_true, label='True Label')
    plt.plot(range(len(y_pred)), y_pred, label='Predicted Label')

    shaded_groups = (
        (counted_events, 'green', 'Overlap event'),
        (fake_events, 'red', 'Fake event'),
        (undetected_events, 'blue', 'Undetected event'),
    )
    for events, color, label in shaded_groups:
        for idx, event in enumerate(events):
            # Only the first span of each category carries the label, so the
            # legend shows one entry per category instead of one per span.
            plt.axvspan(event[0], event[1], alpha=0.3, color=color,
                        label=label if idx == 0 else None)

    # Add labels and title
    plt.xlabel('Index')
    plt.ylabel('Label')
    plt.title('Overlapping Events Visualization')
    plt.legend()
    plt.grid(True)
    plt.show()

# Reference: https://doi.org/10.3390/app6060162