BeveledCube committed · verified · Commit 1a3bc85 · 1 parent: 207c16a

Update main.py

Files changed (1): main.py (+2 -2)
main.py CHANGED
@@ -47,12 +47,12 @@ def read_root(data: req):
     new_user_input_ids = tokenizer.encode(data.prompt + tokenizer.eos_token, return_tensors='pt')
 
     # append the new user input tokens to the chat history
-    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
+    bot_input_ids = torch.cat(new_user_input_ids, dim=-1) if step > 0 else new_user_input_ids
 
     # generated a response while limiting the total chat history to 1000 tokens,
     chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
 
-    generated_text = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+    generated_text = tokenizer.decode(chat_history_ids[:, :][0], skip_special_tokens=True)
     answer_data = { "answer": generated_text }
     print("Answer:", generated_text)
 
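For context, this hunk sits in a DialoGPT-style chat loop: encode the prompt, concatenate it onto the running history, generate, then decode only the new tokens. Below is a minimal standalone sketch of that loop, not the repo's actual handler: the microsoft/DialoGPT-small checkpoint and the plain Python loop are assumptions, since the hunk shows neither which model the repo loads nor the surrounding FastAPI code. The sketch keeps the pre-commit forms of the two changed lines, because torch.cat expects a sequence of tensors (passing a bare tensor raises a TypeError), and slicing at bot_input_ids.shape[-1] is what strips the echoed prompt from the decoded reply.

# Minimal sketch of the generate-and-decode loop touched by this commit.
# Assumptions: the microsoft/DialoGPT-small checkpoint and a plain
# two-turn loop standing in for the FastAPI handler.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

chat_history_ids = None
for step, prompt in enumerate(["Hello!", "How are you?"]):
    # Encode the user prompt, terminated by the EOS token.
    new_user_input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt")

    # Append the new user input to the chat history. torch.cat takes a
    # sequence of tensors, so both tensors go in one list.
    bot_input_ids = (
        torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
        if step > 0
        else new_user_input_ids
    )

    # Generate a response, capping the total history at 1000 tokens.
    chat_history_ids = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    )

    # Slicing past bot_input_ids.shape[-1] keeps only the newly generated
    # reply; decoding chat_history_ids[:, :] would echo the prompt too.
    generated_text = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )
    print("Answer:", generated_text)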