File size: 2,179 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from .rec_metric import RecMetric


class RecGTCMetric(object):
    """Score a two-head (GTC + CTC) recognition model with two ``RecMetric``s.

    The prediction passed to ``__call__`` is indexed as ``pred_label[0]`` for
    the GTC head and ``pred_label[1]`` for the CTC head. The CTC head's metric
    dict is returned as the primary result. The GTC head's ``acc`` and
    ``norm_edit_dis`` are attached to it as ``gtc_acc`` and
    ``gtc_norm_edit_dis``.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        """Build one ``RecMetric`` per head, both with identical settings.

        Args:
            main_indicator: Name of the headline metric key; stored on the
                instance and forwarded to both ``RecMetric`` objects.
            is_filter, ignore_space, stream, with_ratio, max_len, max_ratio:
                Forwarded unchanged to both ``RecMetric`` objects.
            **kwargs: Accepted for config compatibility and ignored; they are
                NOT forwarded to ``RecMetric``.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Not used inside this class; kept so existing readers of the
        # attribute keep working.
        self.eps = 1e-5
        # Single source of truth for both heads' settings so the two metric
        # objects cannot be configured differently by accident.
        metric_cfg = dict(main_indicator=main_indicator,
                          is_filter=is_filter,
                          ignore_space=ignore_space,
                          stream=stream,
                          with_ratio=with_ratio,
                          max_len=max_len,
                          max_ratio=max_ratio)
        self.gtc_metric = RecMetric(**metric_cfg)
        self.ctc_metric = RecMetric(**metric_cfg)

    @staticmethod
    def _merge(ctc_metric, gtc_metric):
        """Attach the GTC scores to the CTC result dict (in place) and return it."""
        ctc_metric['gtc_acc'] = gtc_metric['acc']
        ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
        return ctc_metric

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Score one batch with both heads.

        Args:
            pred_label: Indexable pair; ``[0]`` goes to the GTC metric and
                ``[1]`` to the CTC metric.
            batch: Forwarded to both ``RecMetric`` calls.
            training: Forwarded to both ``RecMetric`` calls.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The CTC metric dict with ``gtc_acc`` and ``gtc_norm_edit_dis``
            added from the GTC metric dict.
        """
        ctc_metric = self.ctc_metric(pred_label[1], batch, training=training)
        gtc_metric = self.gtc_metric(pred_label[0], batch, training=training)
        return self._merge(ctc_metric, gtc_metric)

    def get_metric(self):
        """Return the aggregated metrics of both heads.

        Returns:
            The dict from the CTC ``RecMetric.get_metric()`` (at least
            ``'acc'`` and ``'norm_edit_dis'``), extended with::

                'gtc_acc': <GTC acc>,
                'gtc_norm_edit_dis': <GTC norm_edit_dis>,

        NOTE(review): any state reset after reading is whatever
        ``RecMetric.get_metric`` does; it is not done here.
        """
        ctc_metric = self.ctc_metric.get_metric()
        gtc_metric = self.gtc_metric.get_metric()
        return self._merge(ctc_metric, gtc_metric)