Spaces:
Sleeping
Sleeping
File size: 2,381 Bytes
17ff0d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import nltk # Here to have a nice missing dependency error message early on
from filelock import FileLock
from transformers.utils import is_offline_mode
# Make sure the NLTK "tokenizers/punkt" resource can be found before it is
# needed by the sentence tokenization further down in this module.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if not is_offline_mode():
        # Download goes through a lock file (presumably so that parallel
        # processes do not fetch the data at the same time -- confirm).
        with FileLock(".lock") as lock:
            nltk.download("punkt", quiet=True)
    else:
        # The resource is missing and offline mode is on: fail with a
        # message telling the user how to obtain the data.
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
def string_to_float(string, default=-1.0):
    """Convert ``string`` to a float, falling back to ``default``.

    Args:
        string: Value to convert; normally a text such as ``"3.5"``.
        default: Value returned when the conversion fails.

    Returns:
        ``float(string)`` when the conversion succeeds, otherwise ``default``.
    """
    try:
        return float(string)
    except (TypeError, ValueError):
        # ValueError: text that is not a number (e.g. "abc").
        # TypeError: input float() cannot handle at all (e.g. None).
        return default
def string_to_int(string, default=-1):
    """Convert ``string`` to an int, falling back to ``default``.

    Args:
        string: Value to convert; normally a text such as ``"1"``.
        default: Value returned when the conversion fails.

    Returns:
        ``int(string)`` when the conversion succeeds, otherwise ``default``.
        Note that decimal text such as ``"3.0"`` is not a valid int literal
        and therefore yields ``default``.
    """
    try:
        return int(string)
    except (TypeError, ValueError):
        # ValueError: text that is not an integer literal (e.g. "abc", "3.0").
        # TypeError: input int() cannot handle at all (e.g. None).
        return default
def get_post_processor(task):
    """Pick the converter to run on predictions/targets before scoring.

    ``"stsb"`` maps to ``string_to_float``; ``"qqp"``, ``"cola"`` and
    ``"mrpc"`` map to ``string_to_int``. Any other task needs no
    post-processing and yields ``None``.
    """
    if task == "stsb":
        return string_to_float
    if task in ("qqp", "cola", "mrpc"):
        return string_to_int
    return None
def postprocess_text_for_metric(metric, preds, labels=None, sources=None):
    """Clean up texts and arrange them in the form used for ``metric``.

    Every text is stripped of leading/trailing whitespace. What happens
    next, and what is returned, depends on ``metric``:

    * ``"sari"``: returns ``(preds, labels, sources)``.
    * ``"rouge"``: each prediction and label is split into sentences with
      ``nltk.sent_tokenize`` and re-joined with newlines; returns
      ``(preds, labels)``.
    * ``"bleu"``: each label is wrapped in a one-element list; returns
      ``(preds, labels)``.
    * ``"bertscore"`` / ``"bertscore_them"``: returns ``(preds, labels)``.
    * ``"dist"``: returns ``preds`` only.

    Args:
        metric: Name of the metric (one of the values listed above).
        preds: Iterable of prediction strings.
        labels: Iterable of reference strings; required for every metric
            except ``"dist"``.
        sources: Iterable of source strings; required for ``"sari"``.

    Raises:
        NotImplementedError: If ``metric`` is not one of the supported names.
        ValueError: If ``labels`` or ``sources`` is required for ``metric``
            but was not provided.
    """
    supported = ("sari", "rouge", "bleu", "bertscore", "bertscore_them", "dist")
    if metric not in supported:
        raise NotImplementedError(f"Unsupported metric: {metric!r}")

    # Validate with explicit raises rather than ``assert`` so the checks
    # still run when Python is started with -O.
    if metric != "dist" and labels is None:
        raise ValueError(f"`labels` must be provided for metric {metric!r}")
    if metric == "sari" and sources is None:
        raise ValueError("`sources` must be provided for metric 'sari'")

    preds = [pred.strip() for pred in preds]
    if metric == "dist":
        return preds

    if metric == "bleu":
        # Each label becomes a one-element list of references.
        labels = [[label.strip()] for label in labels]
        return preds, labels

    labels = [label.strip() for label in labels]

    if metric == "sari":
        sources = [source.strip() for source in sources]
        return preds, labels, sources

    if metric == "rouge":
        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    # "rouge", "bertscore" and "bertscore_them" all return the same pair.
    return preds, labels
|