import torch
import time

from torch.utils.data import DataLoader
from datetime import datetime as dt
from params import *
from models.model import ModelWrapper
from utils.metrics import get_mned_metric_from_TruePredict, get_metric_from_TrueWrongPredictV3
from utils.logger import get_logger
from models.sampler import RandomBatchSampler, BucketBatchSampler
from termcolor import colored
import re

class Corrector:
    """Wraps a trained ModelWrapper and exposes single-string correction plus an
    autoregressive evaluation loop over a test dataset."""

    def __init__(self, model_wrapper: ModelWrapper):
        self.model_name = model_wrapper.model_name
        self.model = model_wrapper.model
        self.model_wrapper = model_wrapper
        self.logger = get_logger("./log/test.log")

        self.device = DEVICE

        self.model.to(self.device)
        self.logger.log(f"Device: {self.device}")
        self.logger.log("Loaded model")
    
    def correct_transfomer_with_tr(self, batch, num_beams=2):
        """Correct a single raw input string and return a dict holding the
        prediction, the normalized (noised) input, the original text, and the
        inference time."""
        correction = dict()
        original = batch
        # Split into word-like tokens and single punctuation marks, then re-join
        # with single spaces so the collator receives a normalized input.
        splits = re.findall(r"\w\S+\w|\w+|[^\w\s]", batch)
        batch = " ".join(splits)
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()
            batch = self.model_wrapper.collator.collate([[None, batch, None, None]], type="correct")

            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)
            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']
            correction['original_text'] = original

            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time

        # Detokenize the prediction: drop the spaces that tokenization inserted
        # before punctuation, inside brackets/quotes, and before the percent sign.
        t = re.sub(r"(\s*)([.,:?!;])(\s*)", r"\2\3", correction['predict_text'][0])
        t = re.sub(r"((?P<parenthesis>\()\s)", r"\g<parenthesis>", t)
        t = re.sub(r"(\s(?P<parenthesis>\)))", r"\g<parenthesis>", t)
        t = re.sub(r"((?P<bracket>\[)\s)", r"\g<bracket>", t)
        t = re.sub(r"(\s(?P<bracket>\]))", r"\g<bracket>", t)
        t = re.sub(r"([\'\"])\s(.*)\s([\'\"])", r"\1\2\3", t)
        correction['predict_text'] = re.sub(r"\s(%)", r"\1", t)
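        # Illustrative example of the detokenization above (hypothetical strings,
        # not taken from the project's data):
        #   'Hello , world ( abc ) !'  ->  'Hello, world (abc)!'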
        return correction
    
    def _get_transfomer_with_tr_generations(self, batch, num_beams=2):
        """Run beam-search inference on an already-collated batch and return the
        predictions, the noised inputs, and the wall-clock inference time."""
        correction = dict()
        with torch.no_grad():
            self.model.eval()
            batch_infer_start = time.time()

            result = self.model.inference(batch['batch_src'], num_beams=num_beams,
                                          tokenAligner=self.model_wrapper.collator.tokenAligner)

            correction['predict_text'] = result
            correction['noised_text'] = batch['noised_texts']

            total_infer_time = time.time() - batch_infer_start
            correction['time'] = total_infer_time

        return correction

    def step(self, batch, num_beams=2):
        outputs = self._get_transfomer_with_tr_generations(batch, num_beams)
        batch_predictions = outputs['predict_text']
        batch_label_texts = batch['label_texts']
        batch_noised_texts = batch['noised_texts']

        return batch_predictions, batch_noised_texts, batch_label_texts

    def _evaluation_loop_autoregressive(self, data_loader, num_beams=2):
        """Run beam-search inference over the whole loader, accumulating TP/FP/FN
        counts and the mean normalized edit distance of the predictions (MNED)
        and of the uncorrected inputs (O_MNED)."""
        TP, FP, FN = 0, 0, 0
        MNED = 0.0
        O_MNED = 0.0
        total_infer_time = 0.0
        twp_logger = get_logger(f"./log/true_wrong_predict{time.time()}.log")
        with torch.no_grad():

            self.model.eval()

            for step, batch in enumerate(data_loader):

                batch_infer_start = time.time()

                batch_predictions, batch_noised_texts, batch_label_texts = \
                        self.step(batch, num_beams = num_beams)

                batch_infer_time = time.time() - batch_infer_start

                _TP, _FP, _FN = get_metric_from_TrueWrongPredictV3(
                    batch_label_texts, batch_noised_texts, batch_predictions,
                    self.model_wrapper.tokenAligner.vocab, twp_logger)

                TP += _TP
                FP += _FP
                FN += _FN

                _MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_predictions)
                MNED += _MNED

                _O_MNED = get_mned_metric_from_TruePredict(batch_label_texts, batch_noised_texts)
                O_MNED += _O_MNED

                info = '{} - Evaluate - iter: {:08d}/{:08d} - TP: {} - FP: {} - FN: {} - _MNED: {:.5f} - _O_MNED: {:.5f} - {} time: {:.2f}s'.format(
                    dt.now(),
                    step,
                    self.test_iters,
                    _TP,
                    _FP,
                    _FN,
                    _MNED,
                    _O_MNED,
                    self.device,
                    batch_infer_time)

                self.logger.log(info)

                torch.cuda.empty_cache()
                total_infer_time += time.time() - batch_infer_start
        return total_infer_time, TP, FP, FN, MNED / len(data_loader), O_MNED / len(data_loader)
        
    def evaluate(self, dataset, beams: int = None):
        """Evaluate on `dataset` with `beams` beam-search hypotheses and log
        precision, recall, F1, MNED, and O_MNED."""

        def test_collate_wrapper(batch):
            return self.model_wrapper.collator.collate(batch, type="test")

        if not BUCKET_SAMPLING:
            self.test_sampler = RandomBatchSampler(dataset, VALID_BATCH_SIZE, shuffle=False)
        else:
            self.test_sampler = BucketBatchSampler(dataset, shuffle=True)

        data_loader = DataLoader(dataset=dataset, batch_sampler=self.test_sampler,
                                 collate_fn=test_collate_wrapper)

        self.test_iters = len(data_loader)

        assert beams is not None, "evaluate() requires an explicit beam count"
        total_infer_time, TP, FP, FN, MNED, O_MNED = \
            self._evaluation_loop_autoregressive(data_loader, num_beams=beams)
            

        self.logger.log("Total inference time for this data is: {:4f} secs".format(total_infer_time))
        self.logger.log("###############################################")

        info = f"Metrics for Auto-Regressive with Beam Search number {beams}"
        self.logger.log(colored(info, "green"))

        # Epsilon guards the denominators when a count is zero (same epsilon the
        # F1 computation already uses).
        dc_precision = TP / (TP + FP + 1e-8)
        dc_recall = TP / (TP + FN + 1e-8)
        dc_F1 = 2. * dc_precision * dc_recall / (dc_precision + dc_recall + 1e-8)

        self.logger.log(f"TP: {TP}. FP: {FP}. FN: {FN}")

        self.logger.log(f"Precision: {dc_precision}")
        self.logger.log(f"Recall: {dc_recall}")
        self.logger.log(f"F1: {dc_F1}")
        self.logger.log(f"MNED: {MNED}")
        self.logger.log(f"O_MNED: {O_MNED}")

        return
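

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this class might be driven, assuming a ModelWrapper
# and a test Dataset have been constructed elsewhere; their construction is
# project-specific and the names below (model_wrapper, test_dataset) are
# placeholders, not part of this module:
#
#   model_wrapper = ModelWrapper(...)                  # hypothetical construction
#   corrector = Corrector(model_wrapper)
#   corrector.evaluate(test_dataset, beams=2)          # full evaluation loop
#   out = corrector.correct_transfomer_with_tr("raw input text", num_beams=2)
#   print(out['predict_text'], out['time'])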