kcarnold committed
Commit bfee54e · Parent: 5a64f8e

Looks like I never committed these improvements to the backend.

Files changed (1):
  1. custom_llm_inference.py +44 -97
custom_llm_inference.py CHANGED
@@ -63,37 +63,29 @@ def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     return highlights
 
 
-
-def get_next_token_predictions_inner(
-        model, tokenizer, original_doc, prompt, doc_in_progress, k):
-
-    tokenized_chat = get_tokenized_chat(tokenizer, prompt, original_doc)
-    doc_in_progress_ids = tokenize_doc_in_progress(tokenizer, doc_in_progress)
-
-    device = model.device
-
-    joined_ids = torch.cat([tokenized_chat, doc_in_progress_ids])
-    hypotheses = joined_ids[None].to(model.device)
-
-    # For each of the k next tokens, generate most-likely next tokens and append back on until we
-    # reach a token with a space
-
+def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
+    """
+    For each of the n_branch_tokens next tokens, generate most-likely next tokens and append back on.
+    """
+    assert len(hypotheses.shape) == 2
+    assert hypotheses.shape[0] == 1
+    n_tokens_so_far = hypotheses.shape[1]
     past_key_values = DynamicCache()
 
     with torch.no_grad():
         model_outs_onestep = model(hypotheses, output_hidden_states=True, past_key_values=past_key_values)
 
-    branch_tokens = model_outs_onestep.logits[0, -1].topk(k).indices
+    branch_tokens = model_outs_onestep.logits[0, -1].topk(n_branch_tokens).indices
 
-    # split the cache into k reps. We pretend we're doing a "Beam search"...
-    past_key_values.reorder_cache(torch.zeros((k,), dtype=torch.long, device=device))
+    # split the cache into n_branch_tokens reps. We pretend we're doing a "Beam search"...
+    past_key_values.reorder_cache(torch.zeros((n_branch_tokens,), dtype=torch.long, device=device))
 
     # Now call the model again, passing the kv cache, so we can continue generating.
-    # Each of the k next tokens will be considered as one sequence in a "batch".
+    # Each of the n_branch_tokens next tokens will be considered as one sequence in a "batch".
     next_tokens_as_batch = branch_tokens.unsqueeze(1)
-    assert next_tokens_as_batch.shape == (k, 1)
+    assert next_tokens_as_batch.shape == (n_branch_tokens, 1)
 
-    position_id_for_final_token = joined_ids.shape[0]
+    position_id_for_final_token = n_tokens_so_far
     cache_position = torch.full((1,), position_id_for_final_token, dtype=int, device=device)
     with torch.no_grad():
         model_outs = model(
@@ -105,44 +97,52 @@ def get_next_token_predictions_inner(
             cache_position=cache_position
         )
 
-    # Grab the single most likely token from each of the k sequences
+    # Grab the single most likely token from each of the n_branch_tokens sequences
     next_token_logits = model_outs.logits[:, -1]
     vocab_size = model.config.vocab_size
-    assert next_token_logits.shape == (k, vocab_size), f"{next_token_logits.shape=}, {k=}, {vocab_size=}"
+    assert next_token_logits.shape == (n_branch_tokens, vocab_size), f"{next_token_logits.shape=}, {n_branch_tokens=}, {vocab_size=}"
    most_likely_token_ids = next_token_logits.argmax(dim=-1)
 
     # Stick them at the end of the branch tokens.
-    assert most_likely_token_ids.shape == (k,)
+    assert most_likely_token_ids.shape == (n_branch_tokens,)
     lookahead_sequences = torch.cat([
         branch_tokens.unsqueeze(1),
         most_likely_token_ids.unsqueeze(1)
     ], dim=1)
-    assert lookahead_sequences.shape == (k, 2)
+    assert lookahead_sequences.shape == (n_branch_tokens, 2)
+    return lookahead_sequences, next_token_logits
 
-    decoded_next_tokens = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
-    return decoded_next_tokens, next_token_logits
 
-def get_next_token_predictions_generate(
+def get_next_token_predictions_inner(
         model, tokenizer, original_doc, prompt, doc_in_progress, k):
 
     tokenized_chat = get_tokenized_chat(tokenizer, prompt, original_doc)
     doc_in_progress_ids = tokenize_doc_in_progress(tokenizer, doc_in_progress)
 
+    device = model.device
+
     joined_ids = torch.cat([tokenized_chat, doc_in_progress_ids])
-    context_without_special_tokens = tokenizer.batch_decode(joined_ids, skip_special_tokens=True)
-    prefix_length = len(context_without_special_tokens)
     hypotheses = joined_ids[None].to(model.device)
 
-    generation_output = model.generate(
-        hypotheses,
-        return_dict_in_generate=True,
-        output_scores=True,
-        num_beams=5, num_beam_groups=5, max_new_tokens=10, do_sample=False, diversity_penalty=1e5, top_k=None, num_return_sequences=5)#, token_healing=True, tokenizer=tokenizer)
-    sequences = [
-        decoded[prefix_length:]
-        for decoded in tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
-    ]
-    return sequences,
+    # Alternative approach: chat templates
+    tokenized_chat = tokenizer.apply_chat_template([
+        {"role": "user", "content": f"{prompt}\n\n{original_doc}"},
+        {"role": "assistant", "content": doc_in_progress}
+    ], tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
+
+    # Compare them
+    if tokenized_chat.shape == hypotheses.shape and torch.all(tokenized_chat == hypotheses):
+        print("Tokenized chat and hypotheses match")
+    else:
+        print("FAIL: Tokenized chat and hypotheses do not match!")
+        print(f"{tokenized_chat=}")
+        print(f"{hypotheses=}")
+
+    lookahead_sequences, next_token_logits = get_lookahead_sequences(
+        model, tokenizer, hypotheses, k, device)
+
+    decoded_next_tokens = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
+    return decoded_next_tokens, next_token_logits
 
 
 def get_next_token_predictions_slow(
@@ -196,67 +196,14 @@ def get_next_token_predictions_slow(
 
 
 def continue_messages_inner(model, tokenizer, messages, n_branch_tokens, n_future_tokens):
+    # Note: we're ignoring n_future_tokens right now since the old implementation was buggy.
     device = model.device
 
-    final_message_is_assistant = messages[-1]['role'] == "assistant"
-    print(f"final_message_is_assistant: {final_message_is_assistant}")
-    # if final_message_is_assistant:
-    #     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, continue_final_message=True, return_tensors="pt").to(model.device)
-    # else:
-    #     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
-
     print(tokenizer.batch_decode(tokenized_chat, skip_special_tokens=False))
 
-    # This fails with
-    # RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
-    # generations = model.generate(
-    #     tokenized_chat,
-    #     num_return_sequences=n_branch_tokens,
-    #     num_beam_groups=n_branch_tokens, num_beams=n_branch_tokens,
-    #     do_sample=False, max_new_tokens=n_future_tokens, diversity_penalty=1e5, top_k=None,
-    #     return_dict_in_generate=True, output_scores=True)
-
-    # Instead, we'll do this in two steps:
-    # 1. Get the next token predictions for the k most likely continuations
-    from transformers.cache_utils import DynamicCache
-    past_key_values = DynamicCache()
-    with torch.no_grad():
-        model_outs = model(
-            tokenized_chat,
-            past_key_values=past_key_values,
-            output_hidden_states=True,
-            use_cache=True,
-        )
-    branch_tokens = model_outs.logits[0, -1].topk(n_branch_tokens).indices
-
-    hypotheses = branch_tokens.unsqueeze(1)
-    # Branch off the k most likely continuations
-    past_key_values.reorder_cache(torch.zeros((n_branch_tokens,), dtype=torch.long, device=device))
+    lookahead_sequences, next_token_logits = get_lookahead_sequences(
+        model, tokenizer, tokenized_chat, n_branch_tokens, device)
 
-    # 2. Generate the next n_future_tokens for each branch
-    for i in range(n_future_tokens):
-        position_id_for_final_token = tokenized_chat.shape[0] + i
-        cache_position = torch.full((1,), position_id_for_final_token, dtype=int, device=device)
-        final_token_ids = hypotheses[:, -1:]
-        with torch.no_grad():
-            model_outs = model(
-                final_token_ids,
-                past_key_values=past_key_values,
-                output_hidden_states=True,
-                use_cache=True,
-                cache_position=cache_position
-            )
-
-        # Grab the single most likely token from each of the k sequences
-        next_token_logits = model_outs.logits[:, -1]
-        vocab_size = model.config.vocab_size
-        assert next_token_logits.shape == (n_branch_tokens, vocab_size), f"{next_token_logits.shape=}, {n_branch_tokens=}, {vocab_size=}"
-        most_likely_token_ids = next_token_logits.argmax(dim=-1)
-        hypotheses = torch.cat([
-            hypotheses,
-            most_likely_token_ids.unsqueeze(1)
-        ], dim=1)
-
-    generated_docs = tokenizer.batch_decode(hypotheses, skip_special_tokens=True)
+    generated_docs = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
    return generated_docs
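
For context, a minimal usage sketch (not part of the commit) of how the refactored get_lookahead_sequences helper is meant to be driven, mirroring what continue_messages_inner now does. The checkpoint name, the example messages, and the import of custom_llm_inference are placeholder assumptions for illustration only.

# Sketch only: assumes custom_llm_inference.py is importable and that any
# small chat-tuned causal LM will do; nothing here is specified by the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

from custom_llm_inference import get_lookahead_sequences

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [
    {"role": "user", "content": "Write one sentence about autumn."},
    {"role": "assistant", "content": "The leaves"},
]
# Same tokenization path continue_messages_inner uses: keep the final
# assistant message open so generation continues it.
tokenized_chat = tokenizer.apply_chat_template(
    messages, tokenize=True, return_tensors="pt", continue_final_message=True
).to(model.device)

# Branch on the 5 most likely next tokens, then look one token further ahead
# on each branch; lookahead_sequences is a (5, 2) tensor of token ids.
lookahead_sequences, next_token_logits = get_lookahead_sequences(
    model, tokenizer, tokenized_chat, 5, model.device
)
print(tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True))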