kcarnold committed on
Commit 3c6b0e4
1 Parent(s): 1fa2852

Work around the empty-message bug.

Files changed (1)
  1. custom_llm.py +8 -4
custom_llm.py CHANGED
@@ -209,15 +209,19 @@ def logprobs(request: ContinueMessagesRequest):
     messages = [{"role": m.role, "content": m.content} for m in request.messages]
     if len(messages) == 0:
         raise HTTPException(status_code=400, detail="At least one message must be provided.")
-    n_branch_tokens = request.n_branch_tokens
-    n_future_tokens = request.n_future_tokens
 
     model = ml_models['llm']['model']
     tokenizer = ml_models['llm']['tokenizer']
 
-    device = model.device
-
+    # Work around a bug when the last message is empty
+    trim_last_message = False
+    if messages[-1]['content'] == '':
+        messages[-1]['content'] = '.'
+        trim_last_message = True
     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", continue_final_message=True).to(model.device)
+    if trim_last_message:
+        tokenized_chat = tokenized_chat[:, :-1]
+
 
     # Compute all logits
     with torch.no_grad():
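
Read in isolation, the workaround gives the chat template a single-character placeholder ('.') so that apply_chat_template(..., continue_final_message=True) has a non-empty final message to render, then drops the last token id so the model continues from the genuinely empty message. A minimal sketch of the same idea outside the endpoint follows; the checkpoint name is a placeholder assumption, not necessarily what this app loads into ml_models['llm']:

# Sketch only: the tokenizer checkpoint below is a placeholder assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "user", "content": "Write a haiku about spring."},
    {"role": "assistant", "content": ""},  # empty final message triggers the bug
]

# Give the template something to render, then trim the placeholder token back off.
# Assumes '.' encodes to exactly one token under this chat template.
trim_last_message = messages[-1]["content"] == ""
if trim_last_message:
    messages[-1]["content"] = "."

tokenized_chat = tokenizer.apply_chat_template(
    messages, tokenize=True, return_tensors="pt", continue_final_message=True
)
if trim_last_message:
    tokenized_chat = tokenized_chat[:, :-1]

The trim relies on the placeholder tokenizing to exactly one id; if a template merged it with surrounding whitespace into a longer or shorter sequence, slicing off one token would cut into (or leave residue in) the prompt instead.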