# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding:utf-8 -*-
import os
import logging

import bleu
import weighted_ngram_match
import syntax_match
import dataflow_match


def _load_keywords(lang):
    """Return the keywords listed in ``keywords/<lang>.txt`` beside this file."""
    keywords_path = os.path.join(
        os.path.abspath(os.path.dirname(__file__)), 'keywords', lang + '.txt')
    # The context manager guarantees the handle is closed; a frozenset makes
    # the per-token membership test below O(1) instead of a list scan.
    with open(keywords_path, 'r', encoding='utf-8') as keyword_file:
        return frozenset(line.strip() for line in keyword_file)


def _make_weights(reference_tokens, keywords):
    """Map each reference token to 1 if it is a keyword, otherwise 0.2."""
    return {token: 1 if token in keywords else 0.2
            for token in reference_tokens}


def calc_codebleu(predictions, references, lang, tokenizer=None,
                  params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score and its four component scores.

    The final score is the weighted sum
    ``alpha * ngram + beta * weighted_ngram + gamma * syntax + theta * dataflow``
    where the four weights come from ``params``.

    Args:
        predictions (list[str]): one predicted code string per sample.
        references (list[list[str]]): for every sample, the list of reference
            code strings. A sample given as a bare string is treated as a
            single-reference list.
        lang (str): language name; it selects ``keywords/<lang>.txt`` and is
            forwarded to the syntax and dataflow matchers. The original
            documentation lists
            ['java','js','c_sharp','php','go','python','ruby'].
        tokenizer (callable, optional): maps a code string to a list of
            tokens. Defaults to whitespace splitting (``str.split``).
        params (str, optional): four comma-separated floats
            ``'alpha,beta,gamma,theta'``. Defaults to '0.25,0.25,0.25,0.25'.

    Returns:
        dict: the keys ``'CodeBLEU'``, ``'ngram_match_score'``,
        ``'weighted_ngram_match_score'``, ``'syntax_match_score'`` and
        ``'dataflow_match_score'``.

    Raises:
        ValueError: if ``params`` does not hold exactly four numbers, or if
            ``predictions`` and ``references`` differ in length.
    """
    weights = [float(x) for x in params.split(',')]
    if len(weights) != 4:
        raise ValueError(
            "params must contain exactly 4 comma-separated weights "
            "(alpha,beta,gamma,theta), got {0!r}".format(params))
    alpha, beta, gamma, theta = weights

    # preprocess inputs
    # A bare string would otherwise be iterated character by character, so
    # promote it to a one-element reference list.
    references = [
        [ref.strip()] if isinstance(ref, str) else [x.strip() for x in ref]
        for ref in references
    ]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            "predictions and references must have the same length, "
            "got {0} and {1}".format(len(hypothesis), len(references)))

    # calculate ngram match (BLEU)
    if tokenizer is None:
        tokenizer = str.split

    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match
    keywords = _load_keywords(lang)
    tokenized_refs_with_weights = [
        [[reference_tokens, _make_weights(reference_tokens, keywords)]
         for reference_tokens in reference]
        for reference in tokenized_refs
    ]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang)

    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * dataflow_match_score)

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }