lmzjms's picture
Upload 1162 files
0b32ad6 verified
"""
Metrics for the slot filling SLU task
Authors:
* Yung-Sung Chuang 2021
* Heng-Jui Chang 2022
"""
import re
from typing import Dict, List, Tuple
from .common import cer, wer
# Public metric entry points of this module. The ``slot_edit_f1_full`` and
# ``slot_edit_f1_part`` wrappers are public functions as well, so they are
# exported alongside the metrics they wrap.
__all__ = [
    "slot_type_f1",
    "slot_value_cer",
    "slot_value_wer",
    "slot_edit_f1",
    "slot_edit_f1_full",
    "slot_edit_f1_part",
]
def clean(ref: str) -> str:
    """Remove ``B-<type>`` / ``E-<type>`` slot tags from a tagged transcript.

    An opening tag is removed together with the single space that follows it,
    and a closing tag together with the single space that precedes it. Opening
    tags are stripped first, then closing tags.
    """
    for tag_pattern in (r"B\-(\S+) ", r" E\-(\S+)"):
        ref = re.sub(tag_pattern, "", ref)
    return ref
def parse(hyp: str, ref: str) -> Tuple[str, str, str, str]:
    """Split tagged transcripts into plain text and serialized slot lists.

    Runs of spaces are first collapsed to a single space. Slots are spans of
    the form ``B-<type> <value> E-<type>``. Each slot is serialized as
    ``"<value>:<type>"`` and the slots are joined with ``";"``. Hypothesis
    values additionally have nested tags stripped via ``clean``; reference
    values are kept verbatim.

    Returns:
        ``(ref_text, hyp_text, ref_slots, hyp_slots)``. The text fields have
        all tags removed. A slot field is ``""`` when no slot was found.
    """
    slot_pattern = re.compile(r"B\-(\S+) (.+?) E\-\1")
    hyp, ref = re.sub(r" +", " ", hyp), re.sub(r" +", " ", ref)

    ref_pairs = [f"{value}:{slot_type}" for slot_type, value in slot_pattern.findall(ref)]
    hyp_pairs = [
        f"{clean(value)}:{slot_type}" for slot_type, value in slot_pattern.findall(hyp)
    ]
    return clean(ref), clean(hyp), ";".join(ref_pairs), ";".join(hyp_pairs)
def get_slot_dict(
    hyp: str, ref: str
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Group the slot values of a hypothesis/reference pair by slot type.

    Args:
        hyp: tagged hypothesis transcript.
        ref: tagged reference transcript.

    Returns:
        ``(ref_dict, hyp_dict)``. Each maps a slot type to the list of its
        values in order of appearance. Types with no slots are absent.
    """
    _, _, ref_slots, hyp_slots = parse(hyp, ref)

    def to_dict(serialized: str) -> Dict[str, List[str]]:
        # Inverse of the "value:type;value:type" serialization done by parse().
        slots: Dict[str, List[str]] = {}
        if not serialized:
            return slots
        for item in serialized.split(";"):
            # Split on the LAST colon only. The value is free text and may
            # itself contain ":" (e.g. "10:30"); a plain split(":") raised
            # ValueError there. Slot types are assumed colon-free.
            # NOTE(review): values containing ";" still cannot round-trip
            # through parse()'s serialization.
            value, slot_type = item.rsplit(":", 1)
            slots.setdefault(slot_type, []).append(value)
        return slots

    return to_dict(ref_slots), to_dict(hyp_slots)
def slot_type_f1(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Mean per-utterance F1 over the *set* of slot types.

    For each (hypothesis, reference) pair, the slot types found in each are
    compared as sets:

    * both sets empty: F1 = 1.0
    * exactly one set empty: F1 = 0.0
    * otherwise: precision = |shared| / |hyp types|, recall = |shared| / |ref
      types|, and F1 is their harmonic mean (0.0 when both are 0).

    Pairs are formed with ``zip``, so extra items in the longer list are
    ignored. The result is the unweighted mean over pairs.
    """
    scores = []
    for hyp_line, ref_line in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(hyp_line, ref_line)
        ref_types, hyp_types = set(ref_dict), set(hyp_dict)

        if not ref_types and not hyp_types:
            scores.append(1.0)
            continue
        if not ref_types or not hyp_types:
            scores.append(0.0)
            continue

        shared = len(ref_types & hyp_types)
        recall = shared / len(ref_types)
        precision = shared / len(hyp_types)
        total = precision + recall
        scores.append(2 * precision * recall / total if total > 0 else 0.0)

    return sum(scores) / len(scores)
def _best_value_pairs(
    hypothesis: List[str], groundtruth: List[str], metric
) -> Tuple[List[str], List[str]]:
    """Align every reference slot value with its best-scoring hypothesis value.

    For each utterance pair, every value of every reference slot type is paired
    with the hypothesis value of the same slot type that scores lowest under
    ``metric``. Reference values whose slot type does not appear in the
    hypothesis are paired with the empty string.

    Args:
        hypothesis: tagged hypothesis transcripts.
        groundtruth: tagged reference transcripts, paired with ``hypothesis``
            via ``zip``.
        metric: error function called as ``metric([hyp_value], [ref_value])``
            (``cer`` or ``wer`` in this module); lower is better.

    Returns:
        ``(value_hyps, value_refs)`` as parallel lists, in reference-slot order.
    """
    value_hyps, value_refs = [], []
    for p, t in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(p, t)
        for slot, ref_values in ref_dict.items():
            candidates = hyp_dict.get(slot, [])
            for ref_v in ref_values:
                # A candidate replaces the "" default only if it scores strictly
                # below 100. That threshold is kept from the original code.
                # NOTE(review): whether 100 acts as a 100% cap or as an
                # effective infinity depends on the scale cer/wer return (.common).
                min_err = 100
                best_hyp_v = ""
                for hyp_v in candidates:
                    err = metric([hyp_v], [ref_v])
                    if err < min_err:
                        min_err = err
                        best_hyp_v = hyp_v
                value_refs.append(ref_v)
                value_hyps.append(best_hyp_v)
    return value_hyps, value_refs


def slot_value_cer(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Character error rate between aligned reference and hypothesis slot values.

    Each reference slot value is matched to the lowest-CER hypothesis value of
    the same slot type (or to ``""`` if that type is missing). CER is then
    computed over all matched pairs.
    """
    value_hyps, value_refs = _best_value_pairs(hypothesis, groundtruth, cer)
    return cer(value_hyps, value_refs)


def slot_value_wer(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Word error rate between aligned reference and hypothesis slot values.

    Each reference slot value is matched to the lowest-WER hypothesis value of
    the same slot type (or to ``""`` if that type is missing). WER is then
    computed over all matched pairs.
    """
    value_hyps, value_refs = _best_value_pairs(hypothesis, groundtruth, wer)
    return wer(value_hyps, value_refs)
def slot_edit_f1(
    hypothesis: List[str], groundtruth: List[str], loop_over_all_slot: bool, **kwargs
) -> float:
    """Corpus-level slot-edit F1 = 2*TP / (2*TP + FP + FN).

    Counting, per utterance pair:

    * each reference value whose slot type is absent from the hypothesis: FN
    * each reference value equal to some hypothesis value of the same type: TP
    * each other reference value (a substitution): FN and FP
    * if ``loop_over_all_slot`` is true, each hypothesis value whose slot type
      is absent from the reference: FP

    Extra hypothesis values of a slot type that does occur in the reference
    are not counted. Pairs are formed with ``zip``.

    Returns:
        The F1 score. Returns 1.0 when nothing was counted (no slots to
        evaluate), matching the "both empty => 1.0" convention of
        ``slot_type_f1``; this previously raised ZeroDivisionError.
    """
    tp = fn = fp = 0
    for hyp_line, ref_line in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(hyp_line, ref_line)

        for slot, ref_values in ref_dict.items():
            hyp_values = hyp_dict.get(slot)
            for ref_v in ref_values:
                if hyp_values is None:
                    fn += 1
                elif ref_v in hyp_values:
                    tp += 1
                else:
                    # Substitution: counts as both a miss and a spurious value.
                    fn += 1
                    fp += 1

        if loop_over_all_slot:
            # Hallucinated slot types only count when requested.
            for slot, hyp_values in hyp_dict.items():
                if slot not in ref_dict:
                    fp += len(hyp_values)

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 1.0
    return 2 * tp / denominator
def slot_edit_f1_full(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Slot-edit F1 that also penalizes hypothesis slot types absent from the reference."""
    score = slot_edit_f1(
        hypothesis,
        groundtruth,
        loop_over_all_slot=True,
        **kwargs,
    )
    return score
def slot_edit_f1_part(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Slot-edit F1 evaluated only over slot types that occur in the reference."""
    score = slot_edit_f1(
        hypothesis,
        groundtruth,
        loop_over_all_slot=False,
        **kwargs,
    )
    return score