Spaces:
Sleeping
Sleeping
File size: 5,551 Bytes
9fab08e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from utils.logger import get_logger
import numpy as np
from rapidfuzz.distance.Levenshtein import normalized_distance
import multiprocessing
import time
import utils.alignment as alignment
def _get_mned_metric_from_TruePredict(true_text, predict_text):
    """Return the normalized Levenshtein distance of one prediction.

    Delegates to rapidfuzz's ``Levenshtein.normalized_distance`` with the
    prediction as the first argument and the reference as the second.

    Args:
        true_text: reference string.
        predict_text: predicted string.

    Returns:
        The normalized edit distance reported by rapidfuzz.
    """
    distance = normalized_distance(predict_text, true_text)
    return distance
def get_mned_metric_from_TruePredict(batch_true_text, batch_predict_text):
    """Return the mean normalized edit distance over a batch.

    Pairs ``batch_true_text`` with ``batch_predict_text`` element-wise (``zip``
    semantics: extra items in the longer batch are ignored) and averages
    ``_get_mned_metric_from_TruePredict`` over the pairs.

    Args:
        batch_true_text: iterable of reference strings.
        batch_predict_text: iterable of predicted strings.

    Returns:
        The mean distance as a float, or 0.0 when there are no pairs (the
        original divided by a zero count and raised ZeroDivisionError).
    """
    distances = [
        _get_mned_metric_from_TruePredict(true_text, predict_text)
        for true_text, predict_text in zip(batch_true_text, batch_predict_text)
    ]
    if not distances:
        return 0.0
    return sum(distances) / len(distances)
def get_metric_for_tfm(batch_predicts, batch_targets, batch_length):
    """Count token-level matches and mismatches over a batch.

    For every sample the first and last positions are dropped (presumably
    start/end markers -- TODO confirm against the caller), then only the
    first ``length`` remaining positions are compared element-wise.

    Args:
        batch_predicts: per-sample predicted token sequences.
        batch_targets: per-sample target token sequences.
        batch_length: per-sample number of positions to compare.

    Returns:
        ``(num_correct, num_wrong)`` summed over the batch (NumPy integer
        sums, or plain ``0`` for an empty batch).
    """
    correct_total = 0
    wrong_total = 0
    for pred_seq, tgt_seq, seq_len in zip(batch_predicts, batch_targets, batch_length):
        pred_arr = np.array(pred_seq[1:-1][0:seq_len])
        tgt_arr = np.array(tgt_seq[1:-1][0:seq_len])
        correct_total += np.sum(pred_arr == tgt_arr)
        wrong_total += np.sum(pred_arr != tgt_arr)
    return correct_total, wrong_total
def allign_seq2trueseq(seq, true_seq, gap_symbol = "-"):
    """Cut a character-aligned (seq, true_seq) pair into per-word pieces.

    Words are delimited by the spaces of ``true_seq``.  For each word the
    characters of ``seq`` at the same positions are collected.  When ``seq``
    holds ``gap_symbol`` opposite a space of ``true_seq`` (the sequence glued
    two reference words together), the word before that position gets an
    ``"@@"`` suffix and the word after it an ``"@@"`` prefix.  Gap symbols
    are stripped from the true words but kept in the ``seq`` words.  A
    character of ``seq`` that sits opposite a reference space is not copied
    into any word.

    The final word (not followed by a space) is emitted as well; the
    original loop only emitted a word when it reached a space, so the last
    word of every text was silently dropped.

    Args:
        seq: aligned predicted/input text, same length as ``true_seq``.
        true_seq: aligned reference text.
        gap_symbol: character the aligner uses to mark gaps.

    Returns:
        ``(seq_list, true_list)``: parallel lists, one entry per reference word.

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(true_seq) == len(seq)

    def _is_gap(sep):
        return sep is not None and sep == gap_symbol

    seq_list = []
    true_list = []

    def _emit(pred_word, true_word, prev_sep, next_sep):
        # Mark words that the aligned sequence merged with a neighbour.
        if _is_gap(prev_sep):
            pred_word = "@@" + pred_word
        if _is_gap(next_sep):
            pred_word = pred_word + "@@"
        true_list.append(true_word.replace(gap_symbol, ""))
        seq_list.append(pred_word)

    accumulate_true_word = ""
    accumulate_pred_word = ""
    prev_sep = None
    for true_char, pred_char in zip(true_seq, seq):
        if true_char != " ":
            accumulate_true_word += true_char
            accumulate_pred_word += pred_char
            continue
        next_sep = gap_symbol if pred_char == gap_symbol else " "
        _emit(accumulate_pred_word, accumulate_true_word, prev_sep, next_sep)
        accumulate_pred_word = ""
        accumulate_true_word = ""
        prev_sep = next_sep

    # Both accumulators grow in lockstep, so checking one suffices; an empty
    # tail (text ending in a space, or empty text) adds no word.
    if accumulate_true_word:
        _emit(accumulate_pred_word, accumulate_true_word, prev_sep, None)

    return seq_list, true_list
def align_2seq2trueseq(wrong_text, pred_text, true_text, gap_symbol = "-"):
    """Word-align both the input text and the prediction against the reference.

    Each of ``wrong_text`` and ``pred_text`` is character-aligned to
    ``true_text`` with ``alignment.needle`` and then split into words by
    ``allign_seq2trueseq``.

    Args:
        wrong_text: original (possibly misspelled) input text.
        pred_text: model prediction.
        true_text: reference text.
        gap_symbol: single-character gap marker passed to the aligner.

    Returns:
        List of ``(wrong_word, pred_word, true_word)`` triples.

    Raises:
        AssertionError: if ``gap_symbol`` is None or not one character long.
    """
    assert gap_symbol is not None and len(gap_symbol) == 1
    aligned_wrong, true_for_wrong = alignment.needle(wrong_text, true_text, gap_symbol)
    wrong_words, true_words = allign_seq2trueseq(aligned_wrong, true_for_wrong, gap_symbol)
    aligned_pred, true_for_pred = alignment.needle(pred_text, true_text, gap_symbol)
    pred_words = allign_seq2trueseq(aligned_pred, true_for_pred, gap_symbol)[0]
    return list(zip(wrong_words, pred_words, true_words))
def _select_gap_symbol(vocab, texts):
    """Choose a one-character alignment gap symbol that occurs in none of ``texts``.

    Args:
        vocab: object exposing ``chartoken2idx`` (a mapping keyed by token),
            or None.
        texts: iterable of strings the symbol must not appear in.

    Returns:
        ``"-"`` when ``vocab`` is None.  Otherwise, a key of
        ``vocab.chartoken2idx``, skipping the first four keys (presumably
        special tokens -- TODO confirm).  The key must be a single character,
        must not be the space (the word separator), and must not occur in
        ``texts``.  ``"-"`` is preferred; otherwise the smallest such key is
        returned.

    Raises:
        AssertionError: if no candidate symbol is available.
    """
    if vocab is None:
        return "-"
    used = set("".join(texts))
    candidates = {
        token
        for token in list(vocab.chartoken2idx.keys())[4:]
        # align_2seq2trueseq asserts a one-character gap, and " " delimits
        # words in allign_seq2trueseq, so neither may be used as the gap.
        if isinstance(token, str) and len(token) == 1 and token != " "
    }
    usable = candidates - used
    assert usable, "no vocabulary symbol is free to act as the alignment gap"
    # min() rather than set.pop(): str hashing is randomized per process, so
    # pop() could pick a different symbol in every worker.
    return "-" if "-" in usable else min(usable)


def _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab = None):
    """Count word-level TP/FP/FN for one (true, wrong, predict) triple.

    The input and the prediction are word-aligned to the reference.  Then
    each word is scored as follows:

    * input word already correct: the prediction must keep it.  A word that
      only differs by an ``"@@"`` suffix (merged with the next word, which
      then carries the ``"@@"`` prefix and is counted there) is tolerated.
      Any other change is a false positive.
    * input word wrong, prediction equals the reference: true positive.
    * input word wrong, prediction differs: false negative.

    A wrong prediction that splits the reference word into several words
    costs one per extra word instead of one.

    Args:
        true_text: reference text.
        wrong_text: original input text.
        predict_text: model prediction.
        vocab: optional object with ``chartoken2idx``, used to pick a gap
            symbol absent from all three texts.

    Returns:
        ``(TP, FP, FN)`` integer counts.
    """
    gap_symbol = _select_gap_symbol(vocab, (wrong_text, predict_text, true_text))
    # Named word_triples so the imported ``alignment`` module is not shadowed.
    word_triples = align_2seq2trueseq(wrong_text, predict_text, true_text, gap_symbol)
    TP, FP, FN = 0, 0, 0
    for wrong, predict, true in word_triples:
        # Reference words never contain a space (spaces are the separators),
        # so any difference here means the prediction added words.
        extra_words = len(predict.split(" ")) - len(true.split(" "))
        if wrong == true:
            if predict.endswith("@@") and predict[:-2] == true:
                # Only the merge marker differs; the merged neighbour is scored.
                # (The original tested predict[:-2] alone, which also accepted
                # any prediction that appended two arbitrary characters.)
                pass
            elif predict != true:
                if extra_words == 0:
                    FP += 1
                else:
                    assert extra_words > 0
                    FP += extra_words
        elif predict == true:
            TP += 1
        else:
            if extra_words == 0:
                FN += 1
            else:
                assert extra_words > 0
                FN += extra_words
    return TP, FP, FN
def worker_task(true_text, wrong_text, predict_text, vocab):
    """Pool worker: score a single (true, wrong, predict) triple.

    Returns:
        ``(TP, FP, FN)`` as produced by ``_get_metric_from_TrueWrongPredictV3``.
    """
    tp_count, fp_count, fn_count = _get_metric_from_TrueWrongPredictV3(
        true_text, wrong_text, predict_text, vocab
    )
    scores = (tp_count, fp_count, fn_count)
    return scores
from multiprocessing import Pool
def get_metric_from_TrueWrongPredictV3(batch_true_text, batch_wrong_text, batch_predict_text, vocab, twp_logger):
    """Sum word-level TP/FP/FN over a batch, scoring samples in a process pool.

    Args:
        batch_true_text: reference texts.
        batch_wrong_text: original input texts.
        batch_predict_text: predicted texts.
        vocab: vocabulary passed to every worker (must be picklable).
        twp_logger: optional logger; when truthy, each sample's three texts
            and its "TP - FP - FN" line are written with ``file_only=True``.

    Returns:
        ``(TPs, FPs, FNs)`` summed over the batch.

    Raises:
        AssertionError: if ``vocab`` is None.
    """
    assert vocab is not None
    TPs, FPs, FNs = 0, 0, 0
    data = [
        (true_text, wrong_text, pred_text, vocab)
        for true_text, wrong_text, pred_text in zip(batch_true_text, batch_wrong_text, batch_predict_text)
    ]
    if not data:
        # Nothing to score; avoid spawning worker processes.
        return TPs, FPs, FNs
    # cpu_count() // 3 is 0 on machines with fewer than 3 cores, and
    # Pool(0) raises ValueError, so always use at least one worker.
    n_workers = max(1, multiprocessing.cpu_count() // 3)
    with Pool(n_workers) as pool:
        scores = pool.starmap(worker_task, data)
    for i, (tp, fp, fn) in enumerate(scores):
        TPs += tp
        FPs += fp
        FNs += fn
        if twp_logger:
            twp_logger.log(batch_true_text[i], file_only=True)
            twp_logger.log(batch_wrong_text[i], file_only=True)
            twp_logger.log(batch_predict_text[i], file_only=True)
            twp_logger.log(f"{tp} - {fp} - {fn}", file_only=True)
    return TPs, FPs, FNs