import torch.nn as nn
import torch
import sacrebleu
from tqdm import tqdm
import numpy as np
from loguru import logger

# Log files this module has already registered with loguru. ``logger.add``
# installs a new sink on every call, so without this guard each repeated
# ``evaluate`` call would make every later message appear one more time.
_REGISTERED_LOG_PATHS = set()

# (no_unk, beam_search) -> name of the model method used for decoding.
_TRANSLATE_METHODS = {
    (False, False): 'translate',
    (False, True): 'translate_beam_search',
    (True, False): 'translate_no_unk',
    (True, True): 'translate_no_unk_beam_search',
}


def _ensure_log_sink(exp_name):
    """Register ``./logs/<exp_name>.log`` as a loguru sink at most once."""
    log_path = f"./logs/{exp_name}.log"
    if log_path not in _REGISTERED_LOG_PATHS:
        logger.add(log_path, rotation="500 MB")
        _REGISTERED_LOG_PATHS.add(log_path)


def evaluate(model: nn.Module, config, mode, no_unk=False, beam_search=False):
    """Score ``model`` on a tab-separated parallel corpus with sentence BLEU.

    Each non-blank line of the evaluation file must be ``source<TAB>target``.
    Every pair is scored by :func:`evaluate_one_sentence` and the per-sentence
    scores are averaged.

    Args:
        model: Model to evaluate; it is switched to eval mode. It must provide
            the translate method selected by ``no_unk`` / ``beam_search``.
        config: Object whose ``config['EXP']`` entry names the experiment; the
            name is used for the log file ``./logs/<EXP>.log``.
        mode: ``'test'`` reads ``./data/testing.txt``; any other value reads
            ``./data/validation.txt``.
        no_unk: Forwarded to :func:`evaluate_one_sentence`.
        beam_search: Forwarded to :func:`evaluate_one_sentence`.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)`` holding
        the corpus means of the per-sentence scores. If the file contains no
        sentence pairs the means are NaN.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields.
    """
    _ensure_log_sink(config.config['EXP'])
    model.eval()

    if mode == 'test':
        eval_data_path = './data/testing.txt'
    else:
        eval_data_path = './data/validation.txt'

    none_scores = []
    exp_scores = []
    floor_scores = []
    average_scores = []

    # Gradients are never needed while scoring, so disable autograd tracking.
    with open(eval_data_path, 'r', encoding='utf-8') as f, torch.no_grad():
        for line_no, line in enumerate(tqdm(f), start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{eval_data_path}:{line_no}: expected 'source<TAB>target', "
                    f"got {len(fields)} tab-separated field(s)"
                )
            src, tgt = fields
            none_score, exp_score, floor_score, average_score = evaluate_one_sentence(
                model, src, tgt, no_unk, beam_search
            )
            none_scores.append(none_score)
            exp_scores.append(exp_score)
            floor_scores.append(floor_score)
            average_scores.append(average_score)

    none_score = np.mean(none_scores)
    exp_score = np.mean(exp_scores)
    floor_score = np.mean(floor_scores)
    average_score = np.mean(average_scores)
    logger.info(f"none_score:{none_score}, exp_score:{exp_score}, floor_score:{floor_score}, average_score:{average_score}")
    return none_score, exp_score, floor_score, average_score


def evaluate_one_sentence(model: nn.Module, sentence: str, reference: str, no_unk, beam_search):
    """Translate one sentence and score it against ``reference`` with BLEU.

    The decoding method is chosen from the two flags:

    ========  =============  ================================
    no_unk    beam_search    model method called
    ========  =============  ================================
    False     False          ``translate``
    False     True           ``translate_beam_search``
    True      False          ``translate_no_unk``
    True      True           ``translate_no_unk_beam_search``
    ========  =============  ================================

    Args:
        model: Model providing the translate methods listed above. Each is
            called with ``sentence`` and must return a 2-tuple whose second
            element is the translated sentence string (the first is unused).
        sentence: Source sentence to translate.
        reference: Reference translation to score against.
        no_unk: Truthy to use the ``no_unk`` decoding variant.
        beam_search: Truthy to use the beam-search decoding variant.

    Returns:
        Tuple ``(none_score, exp_score, floor_score, average_score)``: the
        sentence BLEU under sacrebleu's ``'none'``, default (``'exp'``) and
        ``'floor'`` smoothing, followed by the mean of the three.
    """
    # Explicit lookup instead of nested ifs: the nested form previously had
    # the two beam-search branches swapped.
    method_name = _TRANSLATE_METHODS[(bool(no_unk), bool(beam_search))]
    _, translated_sentence = getattr(model, method_name)(sentence)

    logger.info(f'ref:{reference}')
    logger.info(f'translate:{translated_sentence}')

    # No smoothing.
    none_score = sacrebleu.sentence_bleu(translated_sentence, [reference], smooth_method='none').score
    # Exponential smoothing (sacrebleu's default for sentence_bleu).
    exp_score = sacrebleu.sentence_bleu(translated_sentence, [reference]).score
    # Floor smoothing.
    floor_score = sacrebleu.sentence_bleu(translated_sentence, [reference], smooth_method='floor').score

    average_score = (none_score + exp_score + floor_score) / 3
    logger.info(f'none:{none_score} exp:{exp_score} floor:{floor_score}')
    return none_score, exp_score, floor_score, average_score