|
|
|
|
|
|
|
|
|
import os |
|
import logging |
|
import bleu |
|
import weighted_ngram_match |
|
import syntax_match |
|
import dataflow_match |
|
|
|
|
|
def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score and its four component scores.

    The final score is a weighted sum of the n-gram match, the
    keyword-weighted n-gram match, the syntax match and the dataflow match.

    Args:
        predictions (list[str]): one predicted code string per sample.
        references (list[list[str]]): for each sample, a list of reference
            code strings; must have the same length as ``predictions``.
        lang (str): language identifier. It selects the keyword file
            ``keywords/<lang>.txt`` next to this module and is forwarded to
            the syntax and dataflow matchers (the original docstring lists
            'java', 'js', 'c_sharp', 'php', 'go', 'python', 'ruby').
        tokenizer (callable, optional): maps a string to a list of tokens.
            Defaults to whitespace splitting (``s.split()``).
        params (str, optional): four comma-separated floats giving the
            weights of the n-gram, weighted n-gram, syntax and dataflow
            scores, in that order. Defaults to '0.25,0.25,0.25,0.25'.

    Returns:
        dict: the keys 'CodeBLEU', 'ngram_match_score',
        'weighted_ngram_match_score', 'syntax_match_score' and
        'dataflow_match_score'.

    Raises:
        ValueError: if ``params`` does not hold exactly four floats, or if
            ``references`` and ``predictions`` differ in length.
    """
    component_weights = [float(x) for x in params.split(',')]
    if len(component_weights) != 4:
        raise ValueError(
            'params must contain exactly 4 comma-separated weights, got %d: %r'
            % (len(component_weights), params))
    alpha, beta, gamma, theta = component_weights

    # Normalise surrounding whitespace before any comparison.
    references = [[x.strip() for x in ref] for ref in references]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            'references and predictions must have the same length (got %d and %d)'
            % (len(references), len(hypothesis)))

    if tokenizer is None:
        def tokenizer(s):
            return s.split()

    # 1) plain n-gram match (BLEU)
    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # 2) keyword-weighted n-gram match
    keywords_path = os.path.join(
        os.path.abspath(os.path.dirname(__file__)), 'keywords', lang + '.txt')
    # Use a context manager so the file handle is always closed, and a set so
    # the per-token membership test below is O(1).
    with open(keywords_path, 'r', encoding='utf-8') as keywords_file:
        keywords = {line.strip() for line in keywords_file}

    def make_weights(reference_tokens, key_word_list):
        # Language keywords weigh 1, every other token 0.2.
        return {token: 1 if token in key_word_list else 0.2
                for token in reference_tokens}

    tokenized_refs_with_weights = [
        [[reference_tokens, make_weights(reference_tokens, keywords)]
         for reference_tokens in reference]
        for reference in tokenized_refs
    ]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # 3) syntax match (operates on the stripped, untokenized strings)
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang)

    # 4) dataflow match (operates on the stripped, untokenized strings)
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang)

    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * dataflow_match_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score,
    }
|
|