import jittor as jt


def generate(moss, input_str, tokenizer, method, **kwargs):
"""
    Generate a continuation of ``input_str`` with the chosen decoding method.

    :param moss: The MOSS model used for generation.
    :param input_str: The input text (prompt).
    :param tokenizer: Tokenizer.
    :param method: Generation method. Should be one of: ['greedy', 'sample']
:param kwargs: Other parameters used for generation.
- max_gen_len: int. Maximum generate length. Used in all methods.
- temperature: float. Used in ``sample``.
- top_p: float. Used in ``sample``.
- top_k: int. Used in ``sample``.
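
    Example (illustrative sketch; assumes a Jittor MOSS model ``moss`` and a
    Hugging Face tokenizer are already loaded)::

        >>> ids = generate(moss, "Hello!", tokenizer, method="sample",
        ...                max_gen_len=64, temperature=0.7, top_p=0.9, top_k=40)
        >>> print(tokenizer.decode(ids))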
"""
if method == "greedy":
return greedy_search(moss, input_str, tokenizer, **kwargs)
elif method == "sample":
return sample(moss, input_str, tokenizer, **kwargs)
else:
raise NotImplementedError(
f"Unsupported generation method {method}"
)


def greedy_search(model, input_str, tokenizer, max_gen_len,
                  eos_token_id=None, pad_token_id=None):
    """Greedy decoding: repeatedly append the argmax next token until every
    sequence has produced ``eos_token_id`` or the total length reaches
    ``max_gen_len``. Returns the generated ids with the prompt stripped."""
model.eval()
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
if pad_token_id is None and eos_token_id is not None:
pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)  # works for a single id or a list of ids
tokenized = tokenizer(input_str, return_tensors='np')
sentence_ids = jt.Var(tokenized['input_ids'])
attention_mask = jt.Var(tokenized['attention_mask'])
    # 1 marks a sequence that is still generating, 0 one that has finished
    unfinished_sequences = jt.ones((sentence_ids.shape[0],), dtype=sentence_ids.dtype)
past_key_values = None
    while True:
        # with a KV cache from the previous step, only the newest token
        # needs to be fed back to the model
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids
outputs = model(input_ids, past_key_values=past_key_values,
attention_mask=attention_mask)
        # greedy step: take the highest-scoring next token
        next_token_logits = outputs['logits'][:, -1, :].float()
        # jt.argmax returns (indices, values); [0] keeps the indices
        next_tokens = jt.argmax(next_token_logits, dim=-1)[0]
        # finished sequences keep emitting pad_token_id instead of real tokens
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)
        # append the chosen token to the running sequences
sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
# update input
past_key_values = outputs['past_key_values']
        attention_mask = jt.cat(
            [attention_mask, jt.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], dim=-1)
        # mark a sequence as finished once it emits an eos token: compare the
        # new tokens against every eos id and AND the results via prod()
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1)
            .not_equal(eos_token_id_tensor.unsqueeze(1))
            .prod(dim=0)
        )
        jt.sync_all()  # make sure queued ops finish before checking the stop condition
if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
break
    # flatten and strip the prompt; this slicing assumes a batch size of 1
    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]


def sample(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
           eos_token_id=None, pad_token_id=None):
    """Sampling-based decoding: scale the logits by ``temperature``, truncate
    the distribution with top-k and nucleus (top-p) filtering, then draw the
    next token at random. Returns the generated ids with the prompt stripped."""
model.eval()
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
if pad_token_id is None and eos_token_id is not None:
pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)  # works for a single id or a list of ids
tokenized = tokenizer(input_str, return_tensors='np')
sentence_ids = jt.Var(tokenized['input_ids'])
attention_mask = jt.Var(tokenized['attention_mask'])
    # 1 marks a sequence that is still generating, 0 one that has finished
    unfinished_sequences = jt.ones((sentence_ids.shape[0],), dtype=sentence_ids.dtype)
past_key_values = None
    while True:
        # with a KV cache from the previous step, only the newest token
        # needs to be fed back to the model
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids
outputs = model(input_ids, past_key_values=past_key_values,
attention_mask=attention_mask)
next_token_logits = outputs['logits'][:, -1, :].float()
        # sample: temperature scaling flattens (T > 1) or sharpens (T < 1)
        # the distribution before the truncation steps below
        scores = next_token_logits / temperature
        # keep only the top_k highest-scoring tokens
        scores = sample_top_k(scores, top_k)
        # then drop the low-probability tail outside the top-p nucleus
        scores = sample_top_p(scores, top_p)
        probs = jt.nn.softmax(scores, dim=-1)
        next_tokens = jt.multinomial(probs, num_samples=1).squeeze(1)
        # finished sequences keep emitting pad_token_id instead of real tokens
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
past_key_values = outputs['past_key_values']
        attention_mask = jt.cat(
            [attention_mask, jt.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], dim=-1)
        # mark a sequence as finished once it emits an eos token: compare the
        # new tokens against every eos id and AND the results via prod()
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1)
            .not_equal(eos_token_id_tensor.unsqueeze(1))
            .prod(dim=0)
        )
        jt.sync_all()  # make sure queued ops finish before checking the stop condition
if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
break
    # flatten and strip the prompt; this slicing assumes a batch size of 1
    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]


def sample_top_k(scores, top_k):
    """Mask every logit below the ``top_k``-th largest to -inf so that
    softmax assigns it zero probability."""
    top_k = min(top_k, scores.shape[-1])  # safety check: never exceed the vocab size
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < jt.topk(scores, top_k)[0][..., -1, None]
scores = scores.masked_fill(indices_to_remove, -float("Inf"))
return scores
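
# Illustrative only: with a single row of logits [1.0, 3.0, 2.0, 0.5] and
# top_k = 2, the 2nd-highest logit is 2.0, so every entry below it is masked
# and the row becomes [-inf, 3.0, 2.0, -inf].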


def sample_top_p(scores, top_p):
    """Nucleus (top-p) filtering: mask the low-probability tail so that the
    surviving tokens together carry at least ``top_p`` of the probability
    mass."""
    sorted_logits, sorted_indices = jt.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # ascending sort: tokens whose cumulative mass stays within the bottom
    # (1 - top_p) form the tail to drop; the top token is always kept
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, -float("Inf"))
return scores
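
# Illustrative only: if a row's sorted (ascending) softmax is
# [0.1, 0.2, 0.3, 0.4], the cumulative sums are [0.1, 0.3, 0.6, 1.0]; with
# top_p = 0.8 only the first entry satisfies cumsum <= 1 - 0.8 = 0.2, so just
# that lowest-probability token is masked and 0.9 of the mass survives.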