Th3BossC commited on
Commit
79b7e57
·
1 Parent(s): 8096bbf

fixed output generation error

Browse files
Files changed (1) hide show
  1. ChitChat/common/utils.py +8 -7
ChitChat/common/utils.py CHANGED
@@ -43,13 +43,14 @@ def conversation(user, userInput):
43
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
44
  # print(bot_input_ids)
45
  chat_history_ids = small_model.generate(
46
- bot_input_ids,
47
- num_beams = 4,
48
- no_repeat_ngram_size = 3,
49
- temperature = 0.8,
50
- top_k = 150,
51
- top_p = 0.92,
52
- repetition_penalty = 2.1
 
53
  )
54
  # print(f"chat_history_ids : {type(chat_history_ids)}")
55
  saveChatHistory(user, chat_history_ids)
 
43
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
44
  # print(bot_input_ids)
45
  chat_history_ids = small_model.generate(
46
+ bot_input_ids,
47
+ max_length = 500,
48
+ pad_token_id = tokenizer.eos_token_id,
49
+ no_repeat_ngram_size = 3,
50
+ do_sample = True,
51
+ top_k = 100,
52
+ top_p = 0.7,
53
+ temperature = 0.8
54
  )
55
  # print(f"chat_history_ids : {type(chat_history_ids)}")
56
  saveChatHistory(user, chat_history_ids)