kcarnold committed on
Commit
22b91c8
·
1 Parent(s): 43804be

Add a quick and dirty show-next-token-logits page

Browse files
Files changed (1) hide show
  1. custom_llm.py +53 -0
custom_llm.py CHANGED
@@ -203,6 +203,59 @@ def continue_messages(request: ContinueMessagesRequest):
203
  }
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
if __name__ == "__main__":
    # Launch the FastAPI app with uvicorn when this file is run as a script
    # (PORT is presumably a module-level constant — defined outside this view).
    uvicorn.run(app, host="localhost", port=PORT)
 
203
  }
204
 
205
 
206
@app.post('/api/logprobs')
def logprobs(request: ContinueMessagesRequest):
    """Return per-position token log-probabilities for the given chat.

    Applies the tokenizer's chat template to ``request.messages``, runs one
    forward pass of the model, and reports, for every token position, the
    decoded token plus a mapping of candidate tokens to log-probabilities:
    the top-k (k = ``request.n_branch_tokens``) candidates, plus the token
    that actually occurred if it fell outside the top-k.

    Conventions in the returned list:
      - The first entry has no distribution (nothing precedes it).
      - A final extra entry with ``token=None`` carries the distribution for
        the next, not-yet-generated token.

    Raises:
        HTTPException: 400 if no messages are provided.
    """
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    if not messages:
        raise HTTPException(status_code=400, detail="At least one message must be provided.")

    model = ml_models['llm']['model']
    tokenizer = ml_models['llm']['tokenizer']

    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt", continue_final_message=True
    ).to(model.device)

    # A single forward pass yields logits for every position; no sampling
    # or generation is needed for this endpoint.
    with torch.no_grad():
        logits = model(tokenized_chat).logits

    k = request.n_branch_tokens
    n_tokens = len(tokenized_chat[0])

    all_logprobs = []
    # Iterate one position past the end so the last entry (token=None)
    # exposes the distribution over the *next* token after the prompt.
    for idx in range(n_tokens + 1):
        if idx == n_tokens:
            actual_token_id = None
            token = None
        else:
            actual_token_id = tokenized_chat[0, idx].item()
            token = tokenizer.decode(actual_token_id)

        if idx == 0:
            # First token has no preceding context, hence no distribution.
            # NOTE(review): this is an empty *list* while later entries are
            # dicts — clients must handle both; confirm this is intentional.
            token_logprobs = []
        else:
            # Logits at position idx-1 predict the token at position idx.
            # (Renamed from `logprobs`, which shadowed this function's name.)
            position_logprobs = logits[0, idx - 1].log_softmax(dim=-1)
            token_ids_to_return = position_logprobs.topk(k).indices.cpu().numpy().tolist()
            # Always include the token that actually occurred, even when it
            # was not among the top-k candidates.
            if actual_token_id is not None and actual_token_id not in token_ids_to_return:
                token_ids_to_return.append(actual_token_id)
            token_logprobs = {tokenizer.decode(i): position_logprobs[i].item() for i in token_ids_to_return}
        all_logprobs.append(dict(token=token, logprobs=token_logprobs))

    return {
        'logprobs': all_logprobs
    }
259
 
260
if __name__ == "__main__":
    # Launch the FastAPI app with uvicorn when this file is run as a script
    # (PORT is presumably a module-level constant — defined outside this view).
    uvicorn.run(app, host="localhost", port=PORT)