changed output generation settings
Browse files- ChitChat/common/utils.py +17 -8
ChitChat/common/utils.py
CHANGED
@@ -43,13 +43,13 @@ def conversation(user, userInput):
|
|
43 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
44 |
# print(bot_input_ids)
|
45 |
chat_history_ids = small_model.generate(
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
)
|
54 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
55 |
saveChatHistory(user, chat_history_ids)
|
@@ -57,5 +57,14 @@ def conversation(user, userInput):
|
|
57 |
|
58 |
def complexChat(userInput):
|
59 |
input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
|
60 |
-
outputs = large_model.generate(input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
return large_tokenizer.decode(outputs[0], skip_special_tokens = True)
|
|
|
43 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
44 |
# print(bot_input_ids)
|
45 |
chat_history_ids = small_model.generate(
|
46 |
+
bot_input_ids,
|
47 |
+
num_beams = 4,
|
48 |
+
no_repeat_ngram_size = 3,
|
49 |
+
temperature = 0.8,
|
50 |
+
top_k = 150,
|
51 |
+
top_p = 0.92,
|
52 |
+
repetition_penalty = 2.1
|
53 |
)
|
54 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
55 |
saveChatHistory(user, chat_history_ids)
|
|
|
57 |
|
58 |
def complexChat(userInput):
    """Generate a long-form reply to *userInput* using the large model.

    The raw message is tokenized, run through ``large_model.generate`` with
    beam-sampling settings, and the first returned sequence is decoded back
    to text with special tokens stripped.

    Args:
        userInput: the raw user message (str).

    Returns:
        str: the decoded model response.
    """
    input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
    # FIX: `temperature`, `top_k` and `top_p` are sampling parameters; the
    # Hugging Face `generate` API silently ignores them (emitting a warning)
    # unless sampling is enabled. `do_sample=True` is required for these
    # settings to take effect — without it the call is plain beam search.
    outputs = large_model.generate(
        input_ids,
        do_sample = True,          # enable sampling so temperature/top_k/top_p apply
        min_length = 20,           # avoid one-word answers
        max_new_tokens = 600,
        length_penalty = 1.6,      # >1 favours longer beam hypotheses
        num_beams = 4,
        no_repeat_ngram_size = 3,  # block verbatim 3-gram repetition
        temperature = 0.8,
        top_k = 150,
        top_p = 0.92,
        repetition_penalty = 2.1,
    )
    return large_tokenizer.decode(outputs[0], skip_special_tokens = True)
|