from typing import Dict, Sequence, Tuple
import re
import numpy as np
import torch


def postprocess_classification_generation(predictions: str) -> str:
    """Return the generated text that precedes any "Prompt"/"Completion" marker."""
    return re.split("Prompt|Completion", predictions, maxsplit=1)[0]
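
# Example (illustrative): the generated text is cut at the first
# "Prompt"/"Completion" marker, e.g.
#   postprocess_classification_generation("golden retriever\nPrompt: next question")
#   returns "golden retriever\n".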


def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
    """Compute the accuracy of a sequence of predictions."""

    def _preprocess_fn(s):
        """Function to preprocess both targets and predictions."""
        return s.lower()

    is_correct = [
        _preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
        for x in predictions
    ]

    return np.mean(is_correct).item()
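
# Example (illustrative): each entry pairs a model prediction with its ground
# truth label, and matching is case-insensitive, e.g.
#   compute_classification_accuracy(
#       [{"prediction": "Cat", "class_label": "cat"},
#        {"prediction": "dog", "class_label": "bird"}]
#   )
#   returns 0.5.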


def compute_shifted_logits_and_labels(
    logits: torch.Tensor, encodings, tokenizer, eoc_token_id
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Helper function to compute shifted logits and labels.

    This allows for straightforward computation of the loss on shift_logits
    and shift_labels such that the nth element of logits computes the n-1th
    element of the original labels (in the outputs, the nth element of logits
    corresponds to the nth element of the labels).

    Elements in shift_labels that correspond to inputs are masked with values
    of -100 (by default in hf, loss is only computed on token IDs >= 0).

    Returns: tuple containing two elements:
        shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
        shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
    """

    labels = encodings["input_ids"].clone()

    # convert padding and EOC tokens to -100 so they are ignored in loss
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == eoc_token_id] = -100

    # Convert all tokens in the prefix, up to and including the separator
    # (<eos>) token, to -100 so they are ignored in the loss.
    for idx in range(len(labels)):
        # Locate the separator as the *last* eos token, searching from the
        # right (end_of_prefix is a negative index), since the first
        # non-padding token of the sequence will also be eos_token (bos_token
        # and eos_token are the same for the tokenizer).
        end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
        labels[idx, : end_of_prefix + 1] = -100

    # Shift so that tokens < n predict n. shift_logits has shape
    # [batch_size, seq_len - 1, vocab_size]; shift_labels has shape
    # [batch_size, seq_len - 1].
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    return shift_logits, shift_labels
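
# Worked example with assumed (illustrative) token IDs: suppose a padded
# prompt tokenizes to
#   input_ids = [<pad>, <eos>, 17, 23, 31, <eos>, 45, <eoc>]
# where 17/23/31 form the prefix, the second <eos> separates prefix from
# target, and 45 is the single target token. After masking,
#   labels    = [-100, -100, -100, -100, -100, -100, 45, -100]
# and after shifting, shift_labels = labels[:, 1:] leaves 45 aligned with the
# logits produced at the separator position, i.e. the position from which the
# model predicts the target.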


def compute_per_sample_probs(
    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
) -> torch.Tensor:
    """Helper function to compute per-sample probability of the input sequence.

    Assumes the <eos> token is used to separate inputs from targets in the
    prompt text.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )

    # Tuple of tensors for unmasked label tokens. The first element of the
    # tuple contains the batch indices; the second element contains the
    # sequence indices.
    unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
    # Tensor where the i^th element is the token_id corresponding to the i^th
    # element of unmasked_indices
    unmasked_token_ids = shift_labels[unmasked_indices]

    # 2-D tensor whose rows are [batch_idx, sequence_position, token_id] for
    # the unmasked target tokens.
    target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
    target_idxs = target_idxs.to(shift_logits.device)

    # Sanity check that every element in the batch has at least one unmasked
    # target token (minlength ensures trailing batch elements with no target
    # tokens are not silently skipped by bincount).
    assert torch.all(
        torch.bincount(target_idxs[:, 0], minlength=len(shift_labels)) != 0
    ), "At least one element in batch has no unmasked target tokens."

    # Renormalize the logits into proper probabilities via a softmax over the
    # vocabulary (last) dimension.
    shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)

    # Compute the probability of the target sequence (as the product of the
    # probability of the individual tokens in the sequence).
    target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
    for i, j, k in target_idxs:
        target_probs[i] *= shift_probs[i, j, k]

    return target_probs
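
# Note: compute_per_sample_loss below returns the mean negative
# log-probability over the same unmasked target tokens, so, up to
# floating-point error, the two quantities are related by
#   target_probs[i] == exp(-per_sample_loss[i] * num_target_tokens[i])
# where num_target_tokens[i] is the number of unmasked labels in sample i.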


def compute_per_sample_loss(
    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
) -> torch.Tensor:
    """Helper function to compute per-sample classification loss.

    Assumes the <eos> token is used to separate inputs from targets in the
    prompt text.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )

    device = shift_logits.device

    # Loss is computed token-wise, on Tensors of shape
    # [batch_size * (seq_len - 1), vocab_size]
    # and returns a loss tensor of shape
    # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
    # in this computation.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(device),
        reduction="none",
    )

    # Reshape to [batch_size, seq_len - 1]
    loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()

    # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
    # that should be ignored in the loss.
    loss_mask = (shift_labels != -100).int().cpu()

    loss *= loss_mask

    # Compute per-sample loss: sum the loss over all unmasked tokens and
    # divide by the number of target tokens to obtain a tensor of shape
    # [batch_size,]. loss_mask is reused here so the division happens on the
    # same (CPU) device as the masked loss.
    loss = loss.sum(dim=1) / loss_mask.sum(dim=1).float()
    return loss
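

if __name__ == "__main__":
    # Minimal self-contained smoke test (illustrative sketch only): fake token
    # IDs, random logits, and a stub tokenizer exercise the helpers above
    # without loading a real model. The IDs below (pad=0, eos=1, eoc=2) and
    # the vocabulary size of 32 are arbitrary assumptions.
    import types

    fake_tokenizer = types.SimpleNamespace(pad_token_id=0, eos_token_id=1)
    eoc_token_id = 2

    # Two prompts: "<pad> <eos> 10 11 <eos> 20 <eoc>" (one target token) and
    # "<eos> 12 <eos> 21 22 <eoc> <pad>" (two target tokens).
    input_ids = torch.tensor(
        [
            [0, 1, 10, 11, 1, 20, 2],
            [1, 12, 1, 21, 22, 2, 0],
        ]
    )
    encodings = {"input_ids": input_ids}

    torch.manual_seed(0)
    logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 32)

    probs = compute_per_sample_probs(encodings, fake_tokenizer, logits, eoc_token_id)
    losses = compute_per_sample_loss(encodings, fake_tokenizer, logits, eoc_token_id)
    print("per-sample target probabilities:", probs)
    print("per-sample classification loss:", losses)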