# Suburst's picture
# Update src/eval.py
# e79bbcf verified
import torch.nn as nn
import torch
import sacrebleu
from tqdm import tqdm
import numpy as np
from loguru import logger
def evaluate(model:nn.Module, config, mode, no_unk=False, beam_search=False):
    """Score ``model`` on an evaluation file with sentence-level BLEU.

    Every non-blank line of the data file is read as
    ``source<TAB>reference``, passed to ``evaluate_one_sentence``, and the
    per-sentence scores are averaged over the file.

    Args:
        model: Translation model; put into eval mode before scoring.
        config: Object whose ``config['EXP']`` entry names the experiment
            (used for the log file name).
        mode: ``'test'`` reads ``./data/testing.txt``; any other value reads
            ``./data/validation.txt``.
        no_unk: Forwarded unchanged to ``evaluate_one_sentence``.
        beam_search: Forwarded unchanged to ``evaluate_one_sentence``.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)`` holding
        the mean of each per-sentence score. If the file contains no
        non-blank lines the means are NaN (``np.mean`` of an empty list).

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            tab-separated fields.
    """
    exp_name = config.config['EXP']
    # NOTE(review): each call registers one more file sink, so calling
    # evaluate() repeatedly in one process duplicates log lines. Left as-is
    # because callers may rely on the sink staying registered - confirm.
    logger.add(f"./logs/{exp_name}.log", rotation="500 MB")
    model.eval()

    eval_data_path = './data/testing.txt' if mode == 'test' else './data/validation.txt'

    none_scores = []
    exp_scores = []
    floor_scores = []
    average_scores = []

    # Evaluation needs no gradients; no_grad avoids building autograd graphs
    # for every translated sentence.
    with open(eval_data_path, 'r', encoding='utf-8') as f, torch.no_grad():
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline at end of file) used
                # to crash the two-name unpack below.
                continue
            src, tgt = line.split('\t')
            none_score, exp_score, floor_score, average_score = evaluate_one_sentence(
                model, src, tgt, no_unk, beam_search
            )
            none_scores.append(none_score)
            exp_scores.append(exp_score)
            floor_scores.append(floor_score)
            average_scores.append(average_score)

    none_score = np.mean(none_scores)
    exp_score = np.mean(exp_scores)
    floor_score = np.mean(floor_scores)
    average_score = np.mean(average_scores)
    logger.info(f"none_score:{none_score}, exp_score:{exp_score}, floor_score:{floor_score}, average_score:{average_score}")
    return none_score, exp_score, floor_score, average_score
def evaluate_one_sentence(model:nn.Module, sentence:str, reference:str, no_unk, beam_search):
    """Translate one sentence and score it against its reference with BLEU.

    The translation method is picked from ``model`` according to the two
    flags, then the output is scored three times with sacrebleu's
    sentence-level BLEU using different smoothing settings.

    Args:
        model: Object providing ``translate``, ``translate_beam_search``,
            ``translate_no_unk`` and ``translate_no_unk_beam_search``; each
            is called with the source sentence and must return a pair whose
            second item is the translated sentence.
        sentence: Source sentence to translate.
        reference: Reference translation to score against.
        no_unk: Truthy selects a ``translate_no_unk*`` method.
        beam_search: Truthy selects a ``*_beam_search`` method.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)`` where
        the last item is the arithmetic mean of the first three.
    """
    # The two beam-search calls were previously crossed (no_unk=False ran
    # translate_no_unk_beam_search and vice versa). Assumes the method names
    # describe their behaviour - confirm against the model class.
    if no_unk:
        translate = model.translate_no_unk_beam_search if beam_search else model.translate_no_unk
    else:
        translate = model.translate_beam_search if beam_search else model.translate

    # Only the joined sentence is scored; the first returned item is unused.
    _, translated_sentence = translate(sentence)

    logger.info(f'ref:{reference}')
    logger.info(f'translate:{translated_sentence}')

    # No smoothing.
    none_score = sacrebleu.sentence_bleu(translated_sentence, [reference], smooth_method='none').score
    # sacrebleu's default smoothing (documented as 'exp').
    exp_score = sacrebleu.sentence_bleu(translated_sentence, [reference]).score
    # Floor smoothing.
    floor_score = sacrebleu.sentence_bleu(translated_sentence, [reference], smooth_method='floor').score

    average_score = (none_score + exp_score + floor_score)/3
    logger.info(f'none:{none_score} exp:{exp_score} floor:{floor_score}')
    return none_score, exp_score, floor_score, average_score