"""Utilities for post-processing predictions and references before metric computation."""

import nltk  # Here to have a nice missing dependency error message early on
from filelock import FileLock
from transformers.utils import is_offline_mode

# The "punkt" sentence tokenizer is needed below to split text into sentences for rougeLSum.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Use a file lock so that concurrent processes do not race to download punkt.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)


def string_to_float(string, default=-1.0):
    """Converts string to float, using default when conversion not possible."""
    try:
        return float(string)
    except ValueError:
        return default


def string_to_int(string, default=-1):
    """Converts string to int, using default when conversion not possible."""
    try:
        return int(string)
    except ValueError:
        return default


def get_post_processor(task):
    """Returns post processor required to apply on the predictions/targets
    before computing metrics for each task."""
    if task == "stsb":
        return string_to_float
    elif task in ["qqp", "cola", "mrpc"]:
        return string_to_int
    else:
        return None


def postprocess_text_for_metric(metric, preds, labels=None, sources=None):
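    """Strips whitespace and reshapes predictions, labels, and (for SARI) sources
    into the format expected by the given metric."""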
    if metric == "sari":
        assert sources is not None, "SARI post-processing requires source texts"
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        sources = [source.strip() for source in sources]
        return preds, labels, sources
    elif metric == "rouge":
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels
    elif metric == "bleu":
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]
        return preds, labels
    elif metric in ["bertscore", "bertscore_them"]:
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        return preds, labels
    elif metric == "dist":
        # Only the predictions are needed for this metric.
        preds = [pred.strip() for pred in preds]
        return preds
    else:
        raise NotImplementedError(f"Post-processing for metric '{metric}' is not implemented.")
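

# Minimal usage sketch: the metric names come from the branches above, but the sample
# predictions and references are made-up placeholders that only illustrate the shapes
# each helper returns.
if __name__ == "__main__":
    sample_preds = ["The cat sat on the mat. It purred loudly."]
    sample_labels = ["A cat was sitting on the mat. It purred."]

    # For ROUGE, each prediction/reference is sentence-split and newline-joined (rougeLSum).
    rouge_preds, rouge_labels = postprocess_text_for_metric("rouge", sample_preds, sample_labels)
    print(rouge_preds, rouge_labels)

    # For BLEU, each reference is wrapped in a list, giving the usual list-of-lists shape.
    bleu_preds, bleu_labels = postprocess_text_for_metric("bleu", sample_preds, sample_labels)
    print(bleu_preds, bleu_labels)

    # GLUE-style tasks map generated strings back to numeric labels.
    print(get_post_processor("stsb")("2.75"))          # 2.75
    print(get_post_processor("cola")("not a number"))  # -1 (the default)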