ysdede commited on
Commit
a928ce7
·
1 Parent(s): 012c0fa

Update default inference parameters in app.py

Browse files

- Increased max_new_tokens to 2048
- Adjusted repetition_penalty to 1.0

Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
- DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
@@ -56,11 +56,11 @@ def generate(
56
  message: str,
57
  chat_history: list[dict],
58
  system_prompt: str = "",
59
- max_new_tokens: int = 1024,
60
  temperature: float = 0.6,
61
  top_p: float = 0.9,
62
  top_k: int = 50,
63
- repetition_penalty: float = 1.2,
64
  ) -> Iterator[str]:
65
  conversation = []
66
  if system_prompt:
@@ -141,7 +141,7 @@ chat_interface = gr.ChatInterface(
141
  minimum=1.0,
142
  maximum=2.0,
143
  step=0.05,
144
- value=1.2,
145
  ),
146
  ],
147
  stop_btn=None,
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
+ DEFAULT_MAX_NEW_TOKENS = 8192
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
 
56
  message: str,
57
  chat_history: list[dict],
58
  system_prompt: str = "",
59
+ max_new_tokens: int = 2048,
60
  temperature: float = 0.6,
61
  top_p: float = 0.9,
62
  top_k: int = 50,
63
+ repetition_penalty: float = 1.0,
64
  ) -> Iterator[str]:
65
  conversation = []
66
  if system_prompt:
 
141
  minimum=1.0,
142
  maximum=2.0,
143
  step=0.05,
144
+ value=1.0,
145
  ),
146
  ],
147
  stop_btn=None,