# codebleu / my_codebleu.py
# NOTE: web file-viewer residue kept as comments (author: idsedykh,
# commit 0a17ff4 "some work", raw / history / blame links, 2.97 kB).
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding:utf-8 -*-
import os
import logging
import bleu
import weighted_ngram_match
import syntax_match
import dataflow_match
def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score of predictions against references.

    The score is a weighted sum of four component scores: n-gram match
    (BLEU), keyword-weighted n-gram match, syntax match and data-flow match,
    each computed by the corresponding helper module.

    Args:
        predictions (list[str]): one candidate program per example.
        references (list[list[str]] | list[str]): for each example, either a
            list of reference programs or a single reference string (a bare
            string is treated as a one-element list).
        lang (str): language key, e.g. one of ['java', 'js', 'c_sharp', 'php',
            'go', 'python', 'ruby']. It is passed to the syntax/dataflow
            matchers and selects the ``keywords/<lang>.txt`` file located next
            to this module.
        tokenizer (callable, optional): maps a string to a list of tokens.
            Defaults to whitespace splitting (``lambda s: s.split()``).
        params (str, optional): four comma-separated weights
            ``alpha,beta,gamma,theta`` applied to the four components, in the
            order listed above. Defaults to '0.25,0.25,0.25,0.25'.

    Returns:
        dict: ``'CodeBLEU'`` (the weighted sum) plus the four component scores
        under ``'ngram_match_score'``, ``'weighted_ngram_match_score'``,
        ``'syntax_match_score'`` and ``'dataflow_match_score'``.

    Raises:
        ValueError: if ``params`` does not contain exactly four numbers, or if
            the number of predictions differs from the number of references.
    """
    weights = [float(x) for x in params.split(',')]
    if len(weights) != 4:
        raise ValueError(
            f"params must contain exactly 4 comma-separated weights, "
            f"got {len(weights)}: {params!r}"
        )
    alpha, beta, gamma, theta = weights

    # preprocess inputs; a bare string reference would otherwise be iterated
    # character by character, so wrap it as a single-reference list
    references = [[ref] if isinstance(ref, str) else ref for ref in references]
    references = [[x.strip() for x in ref] for ref in references]
    hypothesis = [x.strip() for x in predictions]
    if len(references) != len(hypothesis):
        raise ValueError(
            f"number of predictions ({len(hypothesis)}) does not match "
            f"number of references ({len(references)})"
        )

    # calculate ngram match (BLEU)
    if tokenizer is None:
        tokenizer = lambda s: s.split()
    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference] for reference in references]
    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match: keyword tokens weigh 1, all others 0.2
    keywords_path = os.path.join(
        os.path.abspath(os.path.dirname(__file__)), 'keywords', lang + '.txt')
    with open(keywords_path, 'r', encoding='utf-8') as keywords_file:
        # a set keeps the per-token membership test O(1)
        keywords = {line.strip() for line in keywords_file}

    def make_weights(reference_tokens):
        return {token: 1 if token in keywords else 0.2 for token in reference_tokens}

    tokenized_refs_with_weights = [
        [[reference_tokens, make_weights(reference_tokens)] for reference_tokens in reference]
        for reference in tokenized_refs
    ]
    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang)

    logging.getLogger(__name__).debug(
        'ngram match: %s, weighted ngram match: %s, syntax_match: %s, dataflow_match: %s',
        ngram_match_score, weighted_ngram_match_score,
        syntax_match_score, dataflow_match_score)

    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * dataflow_match_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }