Spaces:
Running
Running
from .rec_metric import RecMetric | |
class RecGTCMetric(object):
    """Recognition metric for a model with a GTC branch and a CTC branch.

    Two independent ``RecMetric`` instances are kept, one per branch, built
    with identical settings. Results are reported as the CTC branch's metric
    dict, extended with the GTC branch's ``acc`` and ``norm_edit_dis`` under
    the keys ``gtc_acc`` and ``gtc_norm_edit_dis``.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        """Build the two per-branch ``RecMetric`` instances.

        Args:
            main_indicator: Stored on ``self`` and forwarded to both RecMetric
                instances.
            is_filter: Stored on ``self`` and forwarded to both instances.
            ignore_space: Stored on ``self`` and forwarded to both instances.
            stream: Forwarded to both instances.
            with_ratio: Forwarded to both instances.
            max_len: Forwarded to both instances.
            max_ratio: Forwarded to both instances.
            **kwargs: Accepted so extra config keys do not raise; they are
                ignored and NOT forwarded to RecMetric.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Kept for attribute compatibility; nothing in this class reads it.
        self.eps = 1e-5
        # Both branches must be scored with exactly the same settings, so
        # they share one config instead of two hand-copied argument lists.
        metric_cfg = dict(main_indicator=main_indicator,
                          is_filter=is_filter,
                          ignore_space=ignore_space,
                          stream=stream,
                          with_ratio=with_ratio,
                          max_len=max_len,
                          max_ratio=max_ratio)
        self.gtc_metric = RecMetric(**metric_cfg)
        self.ctc_metric = RecMetric(**metric_cfg)

    @staticmethod
    def _merge(ctc_metric, gtc_metric):
        """Copy GTC ``acc``/``norm_edit_dis`` into the CTC dict (in place) and return it."""
        ctc_metric['gtc_acc'] = gtc_metric['acc']
        ctc_metric['gtc_norm_edit_dis'] = gtc_metric['norm_edit_dis']
        return ctc_metric

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Score one batch with both branches.

        Args:
            pred_label: Indexable with at least two entries. ``pred_label[0]``
                goes to the GTC metric and ``pred_label[1]`` to the CTC metric.
                (Presumably each entry is that branch's post-processed output
                in the form RecMetric expects. TODO confirm against the
                post-processor.)
            batch: Forwarded unchanged to both RecMetric calls.
            training: Forwarded to both RecMetric calls.
            *args, **kwargs: Accepted and ignored.

        Returns:
            dict: The dict returned by the CTC RecMetric call, with
            ``gtc_acc`` and ``gtc_norm_edit_dis`` added from the GTC call.
        """
        ctc_metric = self.ctc_metric(pred_label[1], batch, training=training)
        gtc_metric = self.gtc_metric(pred_label[0], batch, training=training)
        return self._merge(ctc_metric, gtc_metric)

    def get_metric(self):
        """Return the aggregated metrics of both branches.

        Returns:
            dict: The CTC branch's ``RecMetric.get_metric()`` result, extended
            with ``gtc_acc`` and ``gtc_norm_edit_dis`` taken from the GTC
            branch's ``acc`` and ``norm_edit_dis``.
        """
        ctc_metric = self.ctc_metric.get_metric()
        gtc_metric = self.gtc_metric.get_metric()
        return self._merge(ctc_metric, gtc_metric)