|
import torch.nn as nn |
|
import torch |
|
|
|
import sacrebleu |
|
from tqdm import tqdm |
|
import numpy as np |
|
from loguru import logger |
|
|
|
# Log files that already have a loguru sink attached by this module.
_REGISTERED_LOG_FILES = set()


def _ensure_log_sink(log_path):
    """Attach ``log_path`` as a loguru file sink, at most once per path."""
    if log_path not in _REGISTERED_LOG_FILES:
        logger.add(log_path, rotation="500 MB")
        _REGISTERED_LOG_FILES.add(log_path)


def evaluate(model: nn.Module, config, mode, no_unk=False, beam_search=False):
    """Score ``model`` on the test or validation set with sentence-level BLEU.

    Each non-blank line of the data file must hold a source sentence and its
    reference translation separated by a single tab.  Every pair is scored by
    ``evaluate_one_sentence`` and the per-sentence scores are averaged.

    Args:
        model: Translation model; it is switched to eval mode and is NOT
            switched back, so callers that resume training must call
            ``model.train()`` themselves.
        config: Object whose ``config`` attribute is a mapping with an
            ``'EXP'`` entry; that value names the log file under ``./logs``.
        mode: ``'test'`` reads ``./data/testing.txt``; any other value reads
            ``./data/validation.txt``.
        no_unk: Forwarded to ``evaluate_one_sentence``.
        beam_search: Forwarded to ``evaluate_one_sentence``.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)`` holding
        the mean of each per-sentence score.  All four are NaN when the data
        file contains no sentence pairs.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields.
    """
    exp_name = config.config['EXP']
    # Register the file sink only once: calling logger.add() on every
    # evaluation would stack duplicate sinks and repeat every log line.
    _ensure_log_sink(f"./logs/{exp_name}.log")

    model.eval()

    if mode == 'test':
        eval_data_path = './data/testing.txt'
    else:
        eval_data_path = './data/validation.txt'

    none_scores = []
    exp_scores = []
    floor_scores = []
    average_scores = []

    # Evaluation needs no gradients; no_grad() avoids building autograd graphs
    # while the model decodes.
    with open(eval_data_path, 'r', encoding='utf-8') as f, torch.no_grad():
        for line_no, line in enumerate(tqdm(f), start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{eval_data_path}:{line_no}: expected 2 tab-separated "
                    f"fields, got {len(fields)}"
                )
            src, tgt = fields

            none_score, exp_score, floor_score, average_score = evaluate_one_sentence(
                model, src, tgt, no_unk, beam_search
            )
            none_scores.append(none_score)
            exp_scores.append(exp_score)
            floor_scores.append(floor_score)
            average_scores.append(average_score)

    if not none_scores:
        logger.warning(f"no sentence pairs found in {eval_data_path}")

    none_score = np.mean(none_scores)
    exp_score = np.mean(exp_scores)
    floor_score = np.mean(floor_scores)
    average_score = np.mean(average_scores)

    logger.info(f"none_score:{none_score}, exp_score:{exp_score}, floor_score:{floor_score}, average_score:{average_score}")

    return none_score, exp_score, floor_score, average_score
|
|
|
|
|
def evaluate_one_sentence(model: nn.Module, sentence: str, reference: str, no_unk, beam_search):
    """Translate one sentence and score it against ``reference`` with BLEU.

    The decoding method is chosen from the two flags:

    ========  ============  ====================================
    no_unk    beam_search   method called on ``model``
    ========  ============  ====================================
    False     False         ``translate``
    False     True          ``translate_beam_search``
    True      False         ``translate_no_unk``
    True      True          ``translate_no_unk_beam_search``
    ========  ============  ====================================

    Args:
        model: Object providing the translate method selected above; the
            method must return a pair whose second element is the translated
            sentence.
        sentence: Source sentence to translate.
        reference: Reference translation used as the single BLEU reference.
        no_unk: Select the ``no_unk`` decoding variant when true.
        beam_search: Select the beam-search decoding variant when true.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)``: the
        sentence BLEU under the ``'none'``, ``'exp'`` and ``'floor'``
        smoothing methods, and the arithmetic mean of those three.
    """
    # Only the selected attribute is looked up, so the model needs to
    # implement just the decoding variants that are actually requested.
    # NOTE(review): the original routed the two beam-search cases to each
    # other's method; the mapping below follows the method names.
    if no_unk:
        translate = model.translate_no_unk_beam_search if beam_search else model.translate_no_unk
    else:
        translate = model.translate_beam_search if beam_search else model.translate

    # The first element of the returned pair is not needed for scoring.
    _, translated_sentence = translate(sentence)

    logger.info(f'ref:{reference}')
    logger.info(f'translate:{translated_sentence}')

    references = [reference]
    none_score = sacrebleu.sentence_bleu(translated_sentence, references, smooth_method='none').score
    # Smoothing is named explicitly so this score matches its name regardless
    # of the library's default smoothing method.
    exp_score = sacrebleu.sentence_bleu(translated_sentence, references, smooth_method='exp').score
    floor_score = sacrebleu.sentence_bleu(translated_sentence, references, smooth_method='floor').score

    average_score = (none_score + exp_score + floor_score) / 3

    logger.info(f'none:{none_score} exp:{exp_score} floor:{floor_score}')

    return none_score, exp_score, floor_score, average_score
|
|
|
|
|
|
|
|
|
|