from sklearn.metrics import f1_score, recall_score
from zeno import ZenoOptions, distill, metric


@metric
def accuracy(df, ops: ZenoOptions):
    # Percentage of rows where the model output matches the label.
    if len(df) == 0:
        return 0
    return 100 * (df[ops.label_column] == df[ops.output_column]).sum() / len(df)


@metric
def recall(df, ops: ZenoOptions):
    # Macro-averaged recall across classes, scaled to a percentage.
    return 100 * recall_score(
        df[ops.label_column], df[ops.output_column], average="macro"
    )


@metric
def f1(df, ops: ZenoOptions):
    # Macro-averaged F1 score across classes, scaled to a percentage.
    return 100 * f1_score(df[ops.label_column], df[ops.output_column], average="macro")


@distill
def correct(df, ops: ZenoOptions):
    # Per-row boolean: whether the model output matches the label.
    return (df[ops.label_column] == df[ops.output_column]).tolist()


@distill
def output_label(df, ops: ZenoOptions):
    # Expose the raw model output column so it can be used for slicing.
    return df[ops.output_column]