Add a quick and dirty show-next-token-logits page
custom_llm.py CHANGED (+53 -0)
@@ -203,6 +203,59 @@ def continue_messages(request: ContinueMessagesRequest):
     }
 
 
+@app.post('/api/logprobs')
+def logprobs(request: ContinueMessagesRequest):
+
+    messages = [{"role": m.role, "content": m.content} for m in request.messages]
+    if len(messages) == 0:
+        raise HTTPException(status_code=400, detail="At least one message must be provided.")
+    n_branch_tokens = request.n_branch_tokens
+    n_future_tokens = request.n_future_tokens
+
+    model = ml_models['llm']['model']
+    tokenizer = ml_models['llm']['tokenizer']
+
+    device = model.device
+
+    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
+
+    # Compute all logits
+    with torch.no_grad():
+        logits = model(tokenized_chat).logits
+
+    k = request.n_branch_tokens
+
+    # Return a list of tokens:
+    # {
+    #   "token": "the",
+    #   "logprobs": [{"the": -0.1, "a": -0.2, ...}]
+    # }
+    # logprobs are the top-k logprobs for each token, plus the chosen token in case it is not in the top-k
+    # The very first token will have no logprobs, since it is the beginning of the document
+    # The very last token will have "token" set to None, and "logprobs" will be the logprobs for the next token
+
+    all_logprobs = []
+    for idx in range(len(tokenized_chat[0]) + 1):
+        if idx == len(tokenized_chat[0]):
+            actual_token_id = None
+            token = None
+        else:
+            actual_token_id = tokenized_chat[0, idx].item()
+            token = tokenizer.decode(actual_token_id)
+
+        if idx == 0:
+            token_logprobs = []
+        else:
+            logprobs = logits[0, idx - 1].log_softmax(dim=-1)
+            token_ids_to_return = logprobs.topk(k).indices.cpu().numpy().tolist()
+            if actual_token_id is not None and actual_token_id not in token_ids_to_return:
+                token_ids_to_return.append(actual_token_id)
+            token_logprobs = {tokenizer.decode(i): logprobs[i].item() for i in token_ids_to_return}
+        all_logprobs.append(dict(token=token, logprobs=token_logprobs))
+
+    return {
+        'logprobs': all_logprobs
+    }
 
 if __name__ == "__main__":
     uvicorn.run(app, host="localhost", port=PORT)
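
For reference, a minimal client-side sketch of how the new endpoint might be called. This is an assumption-laden example, not part of the commit: it assumes the server is running locally on whatever port PORT is set to (8000 is used as a placeholder) and that ContinueMessagesRequest accepts messages, n_branch_tokens, and n_future_tokens fields, which is what the handler above reads; adjust to the actual schema.

# Hypothetical client for the new /api/logprobs endpoint.
# Assumes custom_llm.py is running locally; 8000 stands in for PORT.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is"},
    ],
    # Fields the handler reads from the request; n_future_tokens is read but unused there.
    "n_branch_tokens": 5,
    "n_future_tokens": 0,
}

resp = requests.post("http://localhost:8000/api/logprobs", json=payload)
resp.raise_for_status()

# Response shape, per the comments in the handler:
# {"logprobs": [{"token": "the", "logprobs": {"the": -0.1, "a": -2.3, ...}}, ...]}
for entry in resp.json()["logprobs"]:
    top = sorted(entry["logprobs"].items(), key=lambda kv: -kv[1])[:3]
    print(repr(entry["token"]), top)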