Th3BossC commited on
Commit
7f9c79a
·
1 Parent(s): aa19489

changed output generation settings

Browse files
Files changed (1) hide show
  1. ChitChat/common/utils.py +17 -8
ChitChat/common/utils.py CHANGED
@@ -43,13 +43,13 @@ def conversation(user, userInput):
43
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
44
  # print(bot_input_ids)
45
  chat_history_ids = small_model.generate(
46
- bot_input_ids,
47
- max_length = 500,
48
- no_repeat_ngram_size = 3,
49
- do_sample = True,
50
- top_k = 100,
51
- top_p = 0.7,
52
- temperature = 0.8
53
  )
54
  # print(f"chat_history_ids : {type(chat_history_ids)}")
55
  saveChatHistory(user, chat_history_ids)
@@ -57,5 +57,14 @@ def conversation(user, userInput):
57
 
58
  def complexChat(userInput):
59
  input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
60
- outputs = large_model.generate(input_ids)
 
 
 
 
 
 
 
 
 
61
  return large_tokenizer.decode(outputs[0], skip_special_tokens = True)
 
43
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
44
  # print(bot_input_ids)
45
  chat_history_ids = small_model.generate(
46
+ bot_input_ids,
47
+ num_beams = 4,
48
+ no_repeat_ngram_size = 3,
49
+ temperature = 0.8,
50
+ top_k = 150,
51
+ top_p = 0.92,
52
+ repetition_penalty = 2.1
53
  )
54
  # print(f"chat_history_ids : {type(chat_history_ids)}")
55
  saveChatHistory(user, chat_history_ids)
 
57
 
58
  def complexChat(userInput):
59
  input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
60
+ outputs = large_model.generate(input_ids,
61
+ min_length = 20,
62
+ max_new_tokens = 600,
63
+ length_penalty = 1.6,
64
+ num_beams = 4,
65
+ no_repeat_ngram_size = 3,
66
+ temperature = 0.8,
67
+ top_k = 150,
68
+ top_p = 0.92,
69
+ repetition_penalty = 2.1)
70
  return large_tokenizer.decode(outputs[0], skip_special_tokens = True)