fixed output generation error
Browse files- ChitChat/common/utils.py +8 -7
ChitChat/common/utils.py
CHANGED
@@ -43,13 +43,14 @@ def conversation(user, userInput):
|
|
43 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
44 |
# print(bot_input_ids)
|
45 |
chat_history_ids = small_model.generate(
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
53 |
)
|
54 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
55 |
saveChatHistory(user, chat_history_ids)
|
|
|
43 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
44 |
# print(bot_input_ids)
|
45 |
chat_history_ids = small_model.generate(
|
46 |
+
bot_input_ids,
|
47 |
+
max_length = 500,
|
48 |
+
pad_token_id = tokenizer.eos_token_id,
|
49 |
+
no_repeat_ngram_size = 3,
|
50 |
+
do_sample = True,
|
51 |
+
top_k = 100,
|
52 |
+
top_p = 0.7,
|
53 |
+
temperature = 0.8
|
54 |
)
|
55 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
56 |
saveChatHistory(user, chat_history_ids)
|