File size: 2,558 Bytes
f8bd4d2
 
e79bbcf
f8bd4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
import torch

import sacrebleu
from tqdm import tqdm
import numpy as np
from loguru import logger

# Log files that already have a loguru sink attached by this module, so that
# calling evaluate() repeatedly does not write every message several times.
_REGISTERED_LOG_SINKS = set()


def _ensure_log_sink(log_path):
    """Attach a rotating loguru file sink for *log_path* at most once."""
    if log_path not in _REGISTERED_LOG_SINKS:
        logger.add(log_path, rotation="500 MB")
        _REGISTERED_LOG_SINKS.add(log_path)


def evaluate(model:nn.Module, config, mode, no_unk=False, beam_search=False):
    """Score *model* on the test or validation file with sentence-level BLEU.

    Each non-blank line of the data file must hold ``source<TAB>target``.
    Every pair is scored by ``evaluate_one_sentence`` and the per-sentence
    scores are averaged over the file.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        config: Object whose ``config['EXP']`` entry names the experiment;
            the name selects the log file ``./logs/<EXP>.log``.
        mode: ``'test'`` reads ``./data/testing.txt``; any other value reads
            ``./data/validation.txt``.
        no_unk: Forwarded to ``evaluate_one_sentence``.
        beam_search: Forwarded to ``evaluate_one_sentence``.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)`` holding
        the mean of each per-sentence score. The means are NaN when the file
        contains no sentence pairs (numpy's result for an empty list).

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields.
    """
    exp_name = config.config['EXP']
    _ensure_log_sink(f"./logs/{exp_name}.log")
    model.eval()
    eval_data_path = './data/testing.txt' if mode == 'test' else './data/validation.txt'

    none_scores = []
    exp_scores = []
    floor_scores = []
    average_scores = []
    # Gradients are never needed while scoring, so skip autograd bookkeeping.
    with open(eval_data_path, 'r', encoding='utf-8') as f, torch.no_grad():
        for line_no, line in enumerate(tqdm(f), start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{eval_data_path}:{line_no}: expected 'source<TAB>target', "
                    f"got {len(fields)} field(s)"
                )
            src, tgt = fields
            none_score, exp_score, floor_score, average_score = evaluate_one_sentence(
                model, src, tgt, no_unk, beam_search
            )
            none_scores.append(none_score)
            exp_scores.append(exp_score)
            floor_scores.append(floor_score)
            average_scores.append(average_score)

    none_score = np.mean(none_scores)
    exp_score = np.mean(exp_scores)
    floor_score = np.mean(floor_scores)
    average_score = np.mean(average_scores)
    logger.info(f"none_score:{none_score}, exp_score:{exp_score}, floor_score:{floor_score}, average_score:{average_score}")
    return none_score, exp_score, floor_score, average_score
    
    
def evaluate_one_sentence(model:nn.Module, sentence:str, reference:str, no_unk, beam_search):
    """Translate one sentence and score it against *reference* with BLEU.

    The decoding method is chosen from the two flags:

    ========  ============  ================================
    no_unk    beam_search   model method called
    ========  ============  ================================
    False     False         ``translate``
    False     True          ``translate_beam_search``
    True      False         ``translate_no_unk``
    True      True          ``translate_no_unk_beam_search``
    ========  ============  ================================

    Each method is called with *sentence* and must return a pair; the second
    item is used as the hypothesis passed to sacrebleu.

    Args:
        model: Model providing the translate methods listed above.
        sentence: Source sentence to translate.
        reference: Reference translation to score against.
        no_unk: Select the ``no_unk`` decoding variants when truthy.
        beam_search: Select the beam-search decoding variants when truthy.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)``: the
        sentence BLEU with ``smooth_method='none'``, with sacrebleu's default
        smoothing, with ``smooth_method='floor'``, and the mean of the three.
    """
    # Pick the decoder that matches BOTH flags; only the selected attribute is
    # looked up on the model.
    if no_unk:
        translate_fn = model.translate_no_unk_beam_search if beam_search else model.translate_no_unk
    else:
        translate_fn = model.translate_beam_search if beam_search else model.translate
    _, translated_sentence = translate_fn(sentence)

    logger.info(f'ref:{reference}')
    logger.info(f'translate:{translated_sentence}')

    references = [reference]
    # No smoothing.
    none_score = sacrebleu.sentence_bleu(translated_sentence, references, smooth_method='none').score
    # Library default smoothing (the 'exp' variant of this metric).
    exp_score = sacrebleu.sentence_bleu(translated_sentence, references).score
    # Floor smoothing.
    floor_score = sacrebleu.sentence_bleu(translated_sentence, references, smooth_method='floor').score

    average_score = (none_score + exp_score + floor_score)/3
    logger.info(f'none:{none_score} exp:{exp_score} floor:{floor_score}')
    return none_score, exp_score, floor_score, average_score