import numpy as np
import torch
import torch.nn.functional as F

from sdlm.metrics.metrics import distinct_n_grams, mauve, zipf
from sdlm.metrics.perplexity import conditional_perplexity, perplexity
from sdlm.metrics.repetition import repetition
from sdlm.utils import convert_to_simplex, join_texts


def sample_logits(sampling_type, logits, top_p, temperature):
    """Samples token ids from logits of shape (batch, seq_len, vocab_size)."""
    # top-p (nucleus) sampling.
    if sampling_type == "top_p":
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        if top_p is not None:
            sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

            # Remove tokens with cumulative probability above the threshold.
            sorted_indices_to_keep = cumsum_probs < top_p

            # Shift the mask right so the first token that crosses the threshold is also kept.
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1

            # Map the keep-mask from sorted order back to vocabulary order.
            indices_to_keep = sorted_indices_to_keep.scatter(
                dim=2, index=sorted_indices, src=sorted_indices_to_keep
            )
            filtered_logits = logits.masked_fill(indices_to_keep == 0, -float("Inf"))

            # sample from the filtered distribution.
            token_ids = torch.distributions.categorical.Categorical(
                logits=filtered_logits
            ).sample()
        else:
            # Without top_p, fall back to greedy (argmax) decoding.
            token_ids = torch.argmax(probs, dim=-1)
    else:
        raise NotImplementedError(f"Sampling type `{sampling_type}` is not implemented.")
    return token_ids
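
# Illustrative usage (an assumption, not from the original file); shapes follow
# the (batch, seq_len, vocab_size) convention implied by the dim=2 scatter above:
#   logits = torch.randn(4, 16, 50265)
#   token_ids = sample_logits("top_p", logits, top_p=0.9, temperature=1.0)
#   token_ids.shape  # torch.Size([4, 16])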


def remove_first_occurrence(string, char):
    """Drops everything up to and including the first occurrence of `char`."""
    # We do not strip, as we need the spaces as well.
    if char in string:
        idx = string.index(char)
        string = string[idx + len(char) :]
    return string
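
# Illustrative example (not in the original): the prefix up to and including
# the first occurrence is dropped.
#   remove_first_occurrence("<s>Hello world", "<s>")  # -> "Hello world"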


def keep_till_first_occurrence(string, chars):
    """Given a list of strings, truncates the text at the earliest occurrence among them."""
    idxs = [string.index(char) for char in chars if char in string]
    if len(idxs):
        min_idx = np.min(idxs)
        string = string[:min_idx]
    return string
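
# Illustrative example (not in the original): truncation happens at the
# earliest match among the given strings.
#   keep_till_first_occurrence("foo</s>bar<pad>", ["</s>", "<pad>"])  # -> "foo"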


def process_text(texts):
    # TODO(rabeeh): for now, we only cover the RoBERTa case.
    texts = [keep_till_first_occurrence(text, ["</s>"]) for text in texts]
    texts = [remove_first_occurrence(text, "<s>") for text in texts]
    return texts
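
# Illustrative example (not in the original), using RoBERTa-style special tokens:
#   process_text(["<s>Hello</s><pad>"])  # -> ["Hello"]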


def split_into_masked_and_unmasked(token_ids, span_mask, return_masked=None):
    """Given a span_mask, splits the given token_ids into masked and unmasked spans.

    If return_masked is True, returns only the masked spans; if False, only the
    unmasked spans; if None, returns both.
    """
    masked = []
    unmasked = []
    prev_mask = span_mask[0]
    span = []

    def update_spans(span, masked, unmasked, mask):
        # TODO: this needs to be here for previous versions of the code.
        # span = torch.stack(span)
        (masked if mask else unmasked).append(span)

    for token_id, mask in zip(token_ids, span_mask):
        if mask == prev_mask:
            span.append(token_id)
        else:
            # Adds the previous span.
            update_spans(span, masked, unmasked, prev_mask)
            prev_mask = mask
            span = [token_id]
    # Adds the last span.
    update_spans(span, masked, unmasked, prev_mask)

    if return_masked is None:
        return masked, unmasked

    return masked if return_masked else unmasked
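
# Illustrative example (not in the original); True marks masked positions:
#   split_into_masked_and_unmasked([1, 2, 3, 4, 5], [True, True, False, False, True])
#   # -> ([[1, 2], [5]], [[3, 4]])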


def concatenate_alternatively(longer, shorter, mark=""):
    """Given two lists of strings, concatenates them alternately.

    We assume the concatenated string starts with an element of the longer list,
    which has either the same length as the shorter list or one extra element.
    Each element of the shorter list can optionally be wrapped in a `mark` string
    on both sides.
    """
    concatenated_str = ""
    for long, short in zip(longer, shorter):
        concatenated_str += long + " " + mark + short + mark + " "
    if len(longer) == len(shorter) + 1:
        return concatenated_str + longer[-1]
    elif len(longer) == len(shorter):
        return concatenated_str[:-1]
    else:
        raise ValueError(
            "`longer` must be the same length as `shorter` or have one extra element."
        )
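
# Illustrative examples (not in the original):
#   concatenate_alternatively(["a", "c"], ["b"], mark="*")  # -> "a *b* c"
#   concatenate_alternatively(["a"], ["b"], mark="*")       # -> "a *b*"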


def aggregate_list(x):
    """Joins a list of strings with single spaces."""
    return " ".join(x)


def logits_projection(logits, sampling_type, top_p, simplex_value, temperature):
    # TODO(rabeeh): huggingface has other sampling strategies, e.g. contrastive
    # search; diffusion-lm also has more variants.
    token_ids = sample_logits(sampling_type, logits, top_p, temperature)
    return convert_to_simplex(token_ids, simplex_value, vocab_size=logits.shape[2])
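
# Illustrative usage (an assumption, not from the original): projects sampled
# token ids back onto the simplex representation via `convert_to_simplex` from
# sdlm.utils. The simplex_value below is an arbitrary placeholder:
#   logits = torch.randn(4, 16, 50265)
#   simplex = logits_projection(logits, "top_p", top_p=0.9, simplex_value=5.0, temperature=1.0)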


def filter_empty(texts):
    """Filters out empty texts and returns the remaining texts and their indices."""
    list_of_tuples = [(text, i) for i, text in enumerate(texts) if text != ""]
    if len(list_of_tuples) == 0:
        return [], []
    non_empty_texts, remained_inds = list(zip(*list_of_tuples))
    return list(non_empty_texts), list(remained_inds)
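
# Illustrative example (not in the original):
#   filter_empty(["", "foo", "", "bar"])  # -> (["foo", "bar"], [1, 3])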


def find(list_: list[int], target: int) -> int:
    """
    Wrapper around `.index` that returns `len(list_)` instead of raising a
    `ValueError` when the target is not in the list.
    """
    try:
        return list_.index(target)
    except ValueError:
        return len(list_)
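
# Illustrative examples (not in the original):
#   find([5, 7, 9], 7)  # -> 1
#   find([5, 7, 9], 8)  # -> 3 (== len(list_), no exception raised)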


def predict_conditional_generated(
    span_masks,
    input_ids,
    tokenizer,
    predicted_token_ids,
    prefix_name,
    skip_special_tokens,
):
    masked = list(
        map(
            lambda x, y: split_into_masked_and_unmasked(x, y, return_masked=True),
            predicted_token_ids,
            span_masks,
        )
    )
    unmasked = list(
        map(
            lambda x, y: split_into_masked_and_unmasked(x, y, return_masked=False),
            input_ids,
            span_masks,
        )
    )
    pred_masked_texts = [
        tokenizer.batch_decode(
            x[: find(x, tokenizer.eos_token_id)], skip_special_tokens=True
        )
        for x in masked
    ]
    pred_unmasked_texts = [
        tokenizer.batch_decode(
            x[: find(x, tokenizer.eos_token_id)], skip_special_tokens=True
        )
        for x in unmasked
    ]
    assert tokenizer.eos_token == "</s>"  # guard against future model changes.
    # remove <s> and </s> tokens if needed.
    if skip_special_tokens:
        pred_masked_texts = [
            [x[0].replace("<s>", "").replace("</s>", "")] for x in pred_masked_texts
        ]
        pred_unmasked_texts = [
            [x[0].replace("<s>", "").replace("</s>", "")] for x in pred_unmasked_texts
        ]
    pred_texts_marked = list(
        map(
            lambda x, y: concatenate_alternatively(x, y, mark="***"),
            pred_unmasked_texts,
            pred_masked_texts,
        )
    )
    aggregated_masked_texts = list(map(aggregate_list, pred_masked_texts))
    predicted_tokens = [
        np.array(item).tolist() for submasked in masked for item in submasked
    ]
    return {
        # prefix_name: pred_texts,
        prefix_name + "_marked": pred_texts_marked,
        prefix_name + "_masked": aggregated_masked_texts,
        prefix_name + "_masked_tokens": predicted_tokens,
    }


def evaluate_generation(
    results,
    data_args,
    causal_model,
    causal_tokenizer,
    is_conditional_generation,
    prefix_lm_eval=False,
    skip_special_tokens=True,
    eval_for_all_metrics=False,
):
    metrics = {}
    # In the prefix-LM case the generated text is unified, so we evaluate only the masked parts.
    if prefix_lm_eval:
        gold_text_key = "gold_texts_masked"
        # For GPT-2, only the `generated_texts_masked` key is present.
        keys = (
            ["generated_texts_masked"]
            if "generated_texts_masked" in results
            else ["pred_texts_from_simplex_masked", "pred_texts_from_logits_masked"]
        )
    else:
        keys = ["pred_texts_from_simplex", "pred_texts_from_logits"]
        gold_text_key = "gold_texts"

    if is_conditional_generation:
        gold_texts = results[gold_text_key]
        if not skip_special_tokens:
            gold_texts = process_text(gold_texts)
    prefixes = results.get("prefixes")

    for key in keys:
        key_metrics = {}
        texts = results[key]
        if not skip_special_tokens:
            texts = process_text(texts)

        non_empty_texts, remained_indices = filter_empty(texts)
        if len(non_empty_texts) == 0:
            continue

        # Perplexity measured by a causal model.
        if prefixes is None:
            key_metrics.update(
                {
                    "perplexity": perplexity(
                        non_empty_texts, causal_model, causal_tokenizer
                    )["mean_perplexity"]
                }
            )
        else:
            non_empty_prefixes = [
                prefix for i, prefix in enumerate(prefixes) if i in remained_indices
            ]
            perplexity_results = conditional_perplexity(
                non_empty_texts, non_empty_prefixes, causal_model, causal_tokenizer
            )
            key_metrics.update(
                {
                    "perplexity": perplexity_results["mean_perplexity"],
                    "total_perplexity": perplexity_results["mean_perplexity_total"],
                }
            )

        # Dist-1,2,3 measurements.
        key_metrics.update(distinct_n_grams(texts))

        # Metrics requiring the gold text.
        if is_conditional_generation and eval_for_all_metrics:
            # Note that we need to pass both context and predicted texts to this metric.
            # remained_gold_texts = [text for i, text in enumerate(gold_texts) if i in remained_indices]
            # remained_prefixes = [text for i, text in enumerate(prefixes) if i in remained_indices]
            texts_with_context = join_texts(prefixes, texts)
            gold_with_context = join_texts(prefixes, gold_texts)
            length = data_args.max_seq_length - data_args.truncation_length
            key_metrics.update(
                mauve(
                    predictions=texts_with_context,
                    references=gold_with_context,
                    length=length,
                )
            )

        if key + "_tokens" in results and eval_for_all_metrics:
            key_metrics.update(repetition(results[key + "_tokens"], causal_tokenizer))
            key_metrics.update(zipf(results[key + "_tokens"]))

        # Adds the metrics.
        key_metrics = {f"{key}_{k}": v for k, v in key_metrics.items()}
        metrics.update(key_metrics)

    return metrics
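
# A hedged call sketch (not in the original); `results` is the dict produced by
# the generation loop, `data_args` needs `.max_seq_length` and
# `.truncation_length`, and the causal model/tokenizer pair (e.g. GPT-2) is an
# illustrative choice:
#   metrics = evaluate_generation(
#       results,
#       data_args,
#       causal_model,
#       causal_tokenizer,
#       is_conditional_generation=True,
#       prefix_lm_eval=True,
#   )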