import torch
import time
from torch.utils.data import DataLoader
from datetime import datetime as dt
from params import *
from models.model import ModelWrapper
from utils.metrics import get_mned_metric_from_TruePredict, get_metric_from_TrueWrongPredictV3
from utils.logger import get_logger
from models.sampler import RandomBatchSampler, BucketBatchSampler
from termcolor import colored
import re
class Corrector:
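    """Inference and evaluation wrapper around a trained correction model.

    Takes a ModelWrapper, corrects single texts with beam search, and evaluates
    a dataset with TP/FP/FN, precision/recall/F1 and MNED metrics.
    """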
    def __init__(self, model_wrapper: ModelWrapper):
        self.model_name = model_wrapper.model_name
        self.model = model_wrapper.model
        self.model_wrapper = model_wrapper
        self.logger = get_logger("./log/test.log")
        self.device = DEVICE
        self.model.to(self.device)
        self.logger.log(f"Device: {self.device}")
        self.logger.log("Loaded model")

    def correct_transfomer_with_tr(self, batch, num_beams=2):
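        """Correct a single input string.

        The text is pre-tokenized with a regex, collated, and corrected with
        beam-search inference; the prediction is then detokenized by removing
        the spaces inserted around punctuation, parentheses, brackets, quotes
        and '%'. Returns a dict with the predicted, noised and original text
        plus the inference time.
        """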
        correction = dict()
        original = batch
        splits = re.findall(r"\w\S+\w|\w+|[^\w\s]", batch)
        batch = " ".join(splits)
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()
            batch = self.model_wrapper.collator.collate([[None, batch, None, None]], type="correct")
            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)
            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']
            correction['original_text'] = original
            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time
        t = re.sub(r"(\s*)([.,:?!;])(\s*)", r"\2\3", correction['predict_text'][0])
        t = re.sub(r"((?P<parenthesis>\()\s)", r"\g<parenthesis>", t)
        t = re.sub(r"(\s(?P<parenthesis>\)))", r"\g<parenthesis>", t)
        t = re.sub(r"((?P<bracket>\[)\s)", r"\g<bracket>", t)
        t = re.sub(r"(\s(?P<bracket>\]))", r"\g<bracket>", t)
        t = re.sub(r"([\'\"])\s(.*)\s([\'\"])", r"\1\2\3", t)
        correction['predict_text'] = re.sub(r"\s(%)", "%", t)
        return correction

    def _get_transfomer_with_tr_generations(self, batch, num_beams=2):
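        """Run beam-search inference on an already collated batch.

        Returns a dict with the predictions, the noised input texts and the
        inference time.
        """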
        correction = dict()
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()
            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)
            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']
            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time
        return correction

    def step(self, batch, num_beams=2):
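        """Run one evaluation step: generate predictions for a collated batch
        and return them together with the noised inputs and gold labels."""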
        outputs = self._get_transfomer_with_tr_generations(batch, num_beams)
        batch_predictions = outputs['predict_text']
        batch_label_texts = batch['label_texts']
        batch_noised_texts = batch['noised_texts']
        return batch_predictions, batch_noised_texts, batch_label_texts

    def _evaluation_loop_autoregressive(self, data_loader, num_beams=2):
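        """Iterate over the data loader, accumulate TP/FP/FN and MNED metrics,
        and log per-batch statistics.

        Returns the accumulated per-batch time, the summed TP/FP/FN counts, and
        MNED / O_MNED averaged over batches (O_MNED is the same distance metric
        computed on the uncorrected, noised input).
        """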
        TP, FP, FN = 0, 0, 0
        MNED = 0.0
        O_MNED = 0.0
        total_infer_time = 0.0
        twp_logger = get_logger(f"./log/true_wrong_predict{time.time()}.log")
        with torch.no_grad():
            self.model.eval()
            for step, batch in enumerate(data_loader):
                batch_infer_start = time.time()
                batch_predictions, batch_noised_texts, batch_label_texts = \
                    self.step(batch, num_beams=num_beams)
                batch_infer_time = time.time() - batch_infer_start
                _TP, _FP, _FN = get_metric_from_TrueWrongPredictV3(
                    batch_label_texts, batch_noised_texts, batch_predictions,
                    self.model_wrapper.tokenAligner.vocab, twp_logger)
                TP += _TP
                FP += _FP
                FN += _FN
                _MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_predictions)
                MNED += _MNED
                _O_MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_noised_texts)
                O_MNED += _O_MNED
                info = '{} - Evaluate - iter: {:08d}/{:08d} - TP: {} - FP: {} - FN: {} - _MNED: {:.5f} - _O_MNED: {:.5f} - {} time: {:.2f}s'.format(
                    dt.now(),
                    step,
                    self.test_iters,
                    _TP,
                    _FP,
                    _FN,
                    _MNED,
                    _O_MNED,
                    self.device,
                    batch_infer_time)
                self.logger.log(info)
                torch.cuda.empty_cache()
                total_infer_time += time.time() - batch_infer_start
        return total_infer_time, TP, FP, FN, MNED / len(data_loader), O_MNED / len(data_loader)

    def evaluate(self, dataset, beams: int = None):
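        """Evaluate the model on a dataset.

        Builds a data loader (random or bucket batching depending on
        BUCKET_SAMPLING), runs the autoregressive evaluation loop with the
        given beam width, and logs precision, recall, F1, MNED and O_MNED.
        """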
        def test_collate_wrapper(batch):
            return self.model_wrapper.collator.collate(batch, type="test")

        if not BUCKET_SAMPLING:
            self.test_sampler = RandomBatchSampler(dataset, VALID_BATCH_SIZE, shuffle=False)
        else:
            self.test_sampler = BucketBatchSampler(dataset, shuffle=True)
        data_loader = DataLoader(dataset=dataset, batch_sampler=self.test_sampler,
                                 collate_fn=test_collate_wrapper)
        self.test_iters = len(data_loader)
        assert beams is not None
        total_infer_time, TP, FP, FN, MNED, O_MNED = self._evaluation_loop_autoregressive(data_loader, num_beams=beams)
        self.logger.log("Total inference time for this data is: {:.4f} secs".format(total_infer_time))
        self.logger.log("###############################################")
        info = f"Metrics for Auto-Regressive with Beam Search number {beams}"
        self.logger.log(colored(info, "green"))
        dc_TP = TP
        dc_FP = FP
        dc_FN = FN
        dc_precision = dc_TP / (dc_TP + dc_FP + 1e-8)
        dc_recall = dc_TP / (dc_TP + dc_FN + 1e-8)
        dc_F1 = 2. * dc_precision * dc_recall / ((dc_precision + dc_recall) + 1e-8)
        self.logger.log(f"TP: {TP}. FP: {FP}. FN: {FN}")
        self.logger.log(f"Precision: {dc_precision}")
        self.logger.log(f"Recall: {dc_recall}")
        self.logger.log(f"F1: {dc_F1}")
        self.logger.log(f"MNED: {MNED}")
        self.logger.log(f"O_MNED: {O_MNED}")
        return
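
# Hypothetical usage sketch (not part of the original file). How ModelWrapper
# and the evaluation dataset are constructed is project-specific, so the lines
# below are placeholders rather than a definitive entry point:
#
#   model_wrapper = ModelWrapper(...)            # project-specific construction
#   corrector = Corrector(model_wrapper)
#   out = corrector.correct_transfomer_with_tr("some noisy text", num_beams=2)
#   print(out['predict_text'])
#   corrector.evaluate(test_dataset, beams=2)    # test_dataset: project-specific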