Spaces:
Runtime error
Runtime error
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict | |
from pathlib import Path | |
import pandas as pd | |
from rouge_cli import calculate_rouge_path | |
from utils import calculate_rouge | |
# Three multi-sentence summaries; the tests in this file pass this list as the
# first positional argument to ``calculate_rouge``.
# Each entry is built from adjacent string literals (implicit concatenation),
# so every continuation fragment carries its own leading space.
PRED = [
    'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
    ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
    " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
    "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
    " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
    " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
    " body.",
    "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
    " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
    " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
    " punishment.",
]
# Three multi-sentence summaries, index-aligned with ``PRED``; the tests in this
# file pass this list as the second positional argument to ``calculate_rouge``.
# Sentences end with " ." (space before the period); that spacing is part of
# the data and must be kept as is.
TGT = [
    'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
    ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
    " had informed his Lufthansa training school of an episode of severe depression, airline says .",
    "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
    " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
    " Israelis .",
    "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
    " death . Organization claims that governments around the world are using the threat of terrorism to advance"
    " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
    " sentences up by 28% .",
]
def test_disaggregated_scores_are_determinstic():
    """Un-aggregated rouge2 scores must not depend on which other keys are requested.

    Calls ``calculate_rouge`` with ``bootstrap_aggregation=False`` twice, once for
    ``rouge2`` + ``rougeL`` and once for ``rouge2`` alone, and compares the mean
    rouge2 f-measure of both runs.
    """
    both_keys = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
    assert isinstance(both_keys, defaultdict)

    rouge2_only = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])

    # Exact equality is intended: both runs score the same pairs for rouge2.
    mean_alongside_rougel = pd.DataFrame(both_keys["rouge2"]).fmeasure.mean()
    mean_alone = pd.DataFrame(rouge2_only["rouge2"]).fmeasure.mean()
    assert mean_alongside_rougel == mean_alone
def test_newline_cnn_improvement():
    """rougeLsum on PRED/TGT must be strictly higher with ``newline_sep=True`` than with ``False``."""
    metric = "rougeLsum"
    with_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[metric])[metric]
    without_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[metric])[metric]
    assert with_sep > without_sep
def test_newline_irrelevant_for_other_metrics():
    """rouge1, rouge2 and rougeL results must be identical for both ``newline_sep`` settings."""
    metrics = ["rouge1", "rouge2", "rougeL"]
    # Evaluate with the separator first, then without, as two independent runs.
    with_sep, without_sep = [
        calculate_rouge(PRED, TGT, newline_sep=flag, rouge_keys=metrics) for flag in (True, False)
    ]
    assert with_sep == without_sep
def test_single_sent_scores_dont_depend_on_newline_sep():
    """Scores for short one- or two-sentence inputs must be equal for both ``newline_sep`` settings."""
    predictions = [
        "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
        'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
    ]
    references = [
        "Margot Frank, died in 1945, a month earlier than previously thought.",
        'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
        " the final seconds on board Flight 9525.",
    ]
    with_sep = calculate_rouge(predictions, references, newline_sep=True)
    without_sep = calculate_rouge(predictions, references, newline_sep=False)
    assert with_sep == without_sep
def test_pegasus_newline():
    """rougeLsum must be higher with the default ``newline_sep`` than with ``newline_sep=False``.

    The prediction contains a literal ``<n>`` marker between its two sentences.
    Both strings are data and are kept verbatim, including leading and trailing
    spaces and the extra opening quote.
    """
    metric = "rougeLsum"
    prediction = [
        """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
    ]
    reference = [
        """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
    ]

    score_without_sep = calculate_rouge(prediction, reference, rouge_keys=[metric], newline_sep=False)[metric]
    # Second run leaves newline_sep at whatever calculate_rouge defaults to.
    score_default = calculate_rouge(prediction, reference, rouge_keys=[metric])[metric]
    assert score_default > score_without_sep
def test_rouge_cli():
    """``calculate_rouge_path`` returns a dict by default and a defaultdict without bootstrap aggregation.

    The fixture directory is a relative path, so it is resolved against the
    current working directory.
    """
    data_dir = Path("examples/seq2seq/test_data/wmt_en_ro")
    source_file = data_dir / "test.source"
    target_file = data_dir / "test.target"

    aggregated = calculate_rouge_path(source_file, target_file)
    assert isinstance(aggregated, dict)

    disaggregated = calculate_rouge_path(source_file, target_file, bootstrap_aggregation=False)
    assert isinstance(disaggregated, defaultdict)