# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to compute official BLEU score. | |
Source: | |
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py | |
""" | |
import collections
import math
import re
import sys
import unicodedata

import numpy as np
import tensorflow as tf, tf_keras

class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    return "".join(
        chr(x)
        for x in range(sys.maxunicode)
        if unicodedata.category(chr(x)).startswith(prefix))


uregex = UnicodeRegex()
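# Building the character classes above scans every code point up to
# sys.maxunicode, which is relatively slow, so a single module-level instance
# is created once and shared by bleu_tokenize below.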

def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See
  https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983

  In our case, the input string is expected to be just one line and no HTML
  entity de-escaping is needed. So we just tokenize on punctuation and
  symbols, except when a punctuation character is both preceded and followed
  by a digit (e.g. a comma/dot as a thousands/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of a sentence
  is NOT tokenized, i.e. the dot stays with the number, because
  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space
  after each sentence). However, this error is already in the original
  mteval-v14.pl and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
  string = uregex.symbol_re.sub(r" \1 ", string)
  return string.split()
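# Illustrative example (not part of the original module): punctuation next to
# a non-digit is split off, while digit-internal separators and a
# sentence-final dot after a number stay attached, e.g.
#   bleu_tokenize("Hello, world! 1,000.5 in 2023.")
#   -> ['Hello', ',', 'world', '!', '1,000.5', 'in', '2023.']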

def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
  """Compute BLEU for two files (reference and hypothesis translation)."""
  ref_lines = tf.io.gfile.GFile(ref_filename).read().strip().splitlines()
  hyp_lines = tf.io.gfile.GFile(hyp_filename).read().strip().splitlines()
  return bleu_on_list(ref_lines, hyp_lines, case_sensitive)
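# Example usage (hypothetical file names; any path readable by tf.io.gfile,
# including GCS paths, should work):
#   score = bleu_wrapper("newstest2014.ref.txt", "newstest2014.hyp.txt")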

def _get_ngrams_with_counter(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.

  Args:
    segment: text segment from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this
      method.

  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
  ngram_counts = collections.Counter()
  for order in range(1, max_order + 1):
    for i in range(0, len(segment) - order + 1):
      ngram = tuple(segment[i:i + order])
      ngram_counts[ngram] += 1
  return ngram_counts
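# Illustrative example: for a three-token segment and max_order=2 this yields
# every unigram and bigram with its count, e.g.
#   _get_ngrams_with_counter(["the", "cat", "sat"], 2)
#   -> Counter({('the',): 1, ('cat',): 1, ('sat',): 1,
#               ('the', 'cat'): 1, ('cat', 'sat'): 1})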

def compute_bleu(reference_corpus,
                 translation_corpus,
                 max_order=4,
                 use_bp=True):
  """Computes BLEU score of translated segments against one or more references.

  Args:
    reference_corpus: list of references for each translation. Each reference
      should be tokenized into a list of tokens.
    translation_corpus: list of translations to score. Each translation should
      be tokenized into a list of tokens.
    max_order: Maximum n-gram order to use when computing BLEU score.
    use_bp: boolean, whether to apply brevity penalty.

  Returns:
    BLEU score.
  """
  reference_length = 0
  translation_length = 0
  bp = 1.0
  geo_mean = 0

  matches_by_order = [0] * max_order
  possible_matches_by_order = [0] * max_order
  precisions = []

  for (references, translations) in zip(reference_corpus, translation_corpus):
    reference_length += len(references)
    translation_length += len(translations)
    ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
    translation_ngram_counts = _get_ngrams_with_counter(
        translations, max_order)

    overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                   for ngram, count in ref_ngram_counts.items())

    for ngram in overlap:
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
      possible_matches_by_order[len(ngram) -
                                1] += translation_ngram_counts[ngram]

  precisions = [0] * max_order
  smooth = 1.0

  for i in range(0, max_order):
    if possible_matches_by_order[i] > 0:
      precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
      if matches_by_order[i] > 0:
        precisions[i] = float(
            matches_by_order[i]) / possible_matches_by_order[i]
      else:
        smooth *= 2
        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
    else:
      precisions[i] = 0.0

  if max(precisions) > 0:
    p_log_sum = sum(math.log(p) for p in precisions if p)
    geo_mean = math.exp(p_log_sum / max_order)

  if use_bp:
    ratio = translation_length / reference_length
    bp = (0. if ratio < 1e-6 else
          math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0)

  bleu = geo_mean * bp
  return np.float32(bleu)
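# Illustrative example: with a single, identical reference/translation pair
# every n-gram precision is 1.0 and the brevity penalty is 1.0, so
#   compute_bleu([["the", "cat", "sat", "on", "the", "mat"]],
#                [["the", "cat", "sat", "on", "the", "mat"]])
#   -> 1.0 (returned as np.float32)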

def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
  """Compute BLEU for two lists of strings (reference and hypothesis)."""
  if len(ref_lines) != len(hyp_lines):
    raise ValueError(
        "Reference and translation files have different number of "
        "lines (%d VS %d). If training only a few steps (100-200), the "
        "translation may be empty." % (len(ref_lines), len(hyp_lines)))
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return compute_bleu(ref_tokens, hyp_tokens) * 100
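# Example usage (illustrative): scores are scaled to the usual 0-100 range and
# lowercased by default, so identical lines give a perfect score:
#   bleu_on_list(["The cat sat."], ["the cat sat."])  # -> 100.0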