import jittor as jt

def generate(moss, input_str, tokenizer, method, **kwargs):
    """
    Choose different methods to generate sentences.

    :param input_str: The input text.
    :param tokenizer: Tokenizer.
    :param method: Generation method. Should be one of: ['greedy', 'sample']
    :param kwargs: Other parameters used for generation.
        - max_gen_len: int. Maximum generate length. Used in all methods.
        - temperature: float. Used in ``sample``.
        - top_p: float. Used in ``sample``.
        - top_k: int. Used in ``sample``.
    """
    if method == "greedy":
        return greedy_search(moss, input_str, tokenizer, **kwargs)
    elif method == "sample":
        return sample(moss, input_str, tokenizer, **kwargs)
    else:
        raise NotImplementedError(
            f"Unsupported generation method {method}"
        )
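
# A hedged usage sketch: `moss` and `tokenizer` are assumed to be a MOSS model
# and its tokenizer loaded elsewhere; the argument values below are illustrative:
#
#     output_ids = generate(moss, "Hello!", tokenizer, method="sample",
#                           max_gen_len=256, temperature=0.8, top_p=0.95, top_k=50)
#     print(tokenizer.decode(output_ids))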

def greedy_search(model, input_str, tokenizer, max_gen_len,
                  eos_token_id=None, pad_token_id=None):
    """Generate by always taking the highest-probability next token."""
    model.eval()
    # Fall back to the tokenizer's EOS token, and pad finished sequences with it.
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)

    tokenized = tokenizer(input_str, return_tensors='np')
    sentence_ids = jt.Var(tokenized['input_ids'])
    attention_mask = jt.Var(tokenized['attention_mask'])
    # 1 marks a sequence that is still generating, 0 one that has hit EOS.
    unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1)
    past_key_values = None
    while True:
        # With a KV cache, only the most recent token needs to be fed to the
        # model; on the first step the full prompt goes in.
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids

        outputs = model(input_ids, past_key_values=past_key_values,
                        attention_mask=attention_mask)
        # calculate next-token logits and pick the argmax
        next_token_logits = outputs['logits'][:, -1, :].float()
        # jt.argmax returns (indices, values); [0] keeps the indices
        next_tokens = jt.argmax(next_token_logits, dim=-1)[0]

        # Finished sequences keep emitting pad_token_id instead of new tokens.
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)
        sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
        # update input
        past_key_values = outputs['past_key_values']
        attention_mask = jt.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # If an EOS token was produced, mark that sequence as finished.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1) \
                       .not_equal(eos_token_id_tensor.unsqueeze(1)) \
                       .prod(dim=0)
        )
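        # e.g. with a single EOS id, eos_token_id_tensor has shape [1]; if EOS
        # is 2 and next_tokens is [5, 2], not_equal yields [1, 0], so
        # unfinished_sequences [1, 1] becomes [1, 0] after the multiply.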

        jt.sync_all()  # synchronize pending Jittor ops before checking stop conditions

        # Stop once every sequence has finished or the length budget is used up.
        if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
            break

    # Return only the newly generated ids (assumes batch size 1).
    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]

def sample(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
           eos_token_id=None, pad_token_id=None):
    """Generate by sampling from a temperature-, top-k- and top-p-filtered distribution."""
    model.eval()
    # Fall back to the tokenizer's EOS token, and pad finished sequences with it.
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)

    tokenized = tokenizer(input_str, return_tensors='np')
    sentence_ids = jt.Var(tokenized['input_ids'])
    attention_mask = jt.Var(tokenized['attention_mask'])
    # 1 marks a sequence that is still generating, 0 one that has hit EOS.
    unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1)
    past_key_values = None

    while True:
        # With a KV cache, only the most recent token needs to be fed to the
        # model; on the first step the full prompt goes in.
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids
        outputs = model(input_ids, past_key_values=past_key_values,
                        attention_mask=attention_mask)

        next_token_logits = outputs['logits'][:, -1, :].float()

        # Sample: apply temperature scaling, then top-k and top-p filtering
        # before drawing from the filtered distribution.
        scores = next_token_logits / temperature
        scores = sample_top_k(scores, top_k)
        scores = sample_top_p(scores, top_p)

        probs = jt.nn.softmax(scores, dim=-1)
        next_tokens = jt.multinomial(probs, num_samples=1).squeeze(1)
        # Finished sequences keep emitting pad_token_id instead of new tokens.
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
        past_key_values = outputs['past_key_values']
        attention_mask = jt.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # If an EOS token was produced, mark that sequence as finished.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1) \
                       .not_equal(eos_token_id_tensor.unsqueeze(1)) \
                       .prod(dim=0)
        )

        jt.sync_all()  # synchronize pending Jittor ops before checking stop conditions

        # Stop once every sequence has finished or the length budget is used up.
        if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
            break

    # Return only the newly generated ids (assumes batch size 1).
    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]

def sample_top_k(scores, top_k):
    """Mask out every logit below the k-th largest one."""
    top_k = min(top_k, scores.shape[-1])  # Safety check
    # Remove all tokens with a logit below the smallest logit of the top-k
    indices_to_remove = scores < jt.topk(scores, top_k)[0][..., -1, None]
    scores = scores.masked_fill(indices_to_remove, -float("Inf"))

    return scores
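
# A worked example of top-k filtering (illustrative values):
#
#     scores = jt.Var([[1.0, 2.0, 3.0, 4.0]])
#     sample_top_k(scores, top_k=2)
#     # -> [[-inf, -inf, 3.0, 4.0]]: only the two largest logits survive.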

def sample_top_p(scores, top_p):
    """Mask out the low-probability tail whose total mass is at most 1 - top_p."""
    sorted_logits, sorted_indices = jt.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # With an ascending sort, tokens whose cumulative probability stays at or
    # below 1 - top_p form the low-probability tail; the tokens that remain
    # cover at least top_p of the probability mass.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Scatter the mask back from sorted order to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, -float("Inf"))

    return scores
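
# A worked example of top-p (nucleus) filtering (illustrative values):
# softmax([1, 2, 3, 4]) ~= [0.032, 0.087, 0.237, 0.644], so the ascending
# cumulative sums are ~[0.032, 0.119, 0.356, 1.0]; with top_p = 0.9 only the
# first token satisfies cumsum <= 0.1 and gets masked:
#
#     scores = jt.Var([[1.0, 2.0, 3.0, 4.0]])
#     sample_top_p(scores, top_p=0.9)
#     # -> [[-inf, 2.0, 3.0, 4.0]]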