Tuchuanhuhuhu commited on
Commit
c857ac1
·
1 Parent(s): cc9e07a

增加了一大堆参数控制

Browse files
Files changed (3) hide show
  1. ChuanhuChatbot.py +59 -6
  2. modules/base_model.py +38 -18
  3. modules/models.py +38 -12
ChuanhuChatbot.py CHANGED
@@ -159,21 +159,74 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
159
  default_btn = gr.Button("🔙 恢复默认设置")
160
 
161
  with gr.Accordion("参数", open=False):
 
 
 
 
 
 
 
 
162
  top_p_slider = gr.Slider(
163
  minimum=-0,
164
  maximum=1.0,
165
  value=1.0,
166
  step=0.05,
167
  interactive=True,
168
- label="Top-p",
169
  )
170
- temperature_slider = gr.Slider(
171
- minimum=-0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  maximum=2.0,
173
- value=1.0,
174
- step=0.1,
 
 
 
 
 
 
 
 
175
  interactive=True,
176
- label="Temperature",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  )
178
 
179
  with gr.Accordion("网络设置", open=False):
 
159
  default_btn = gr.Button("🔙 恢复默认设置")
160
 
161
  with gr.Accordion("参数", open=False):
162
+ temperature_slider = gr.Slider(
163
+ minimum=-0,
164
+ maximum=2.0,
165
+ value=1.0,
166
+ step=0.1,
167
+ interactive=True,
168
+ label="temperature",
169
+ )
170
  top_p_slider = gr.Slider(
171
  minimum=-0,
172
  maximum=1.0,
173
  value=1.0,
174
  step=0.05,
175
  interactive=True,
176
+ label="top-p",
177
  )
178
+ n_choices_slider = gr.Slider(
179
+ minimum=1,
180
+ maximum=1,
181
+ value=1,
182
+ step=1,
183
+ interactive=True,
184
+ label="n choices",
185
+ )
186
+ stop_sequence_txt = gr.Textbox(
187
+ show_label=True,
188
+ placeholder=f"在这里输入停止符,用英文逗号隔开...",
189
+ label="stop",
190
+ value="",
191
+ lines=1,
192
+ )
193
+ max_tokens_slider = gr.Slider(
194
+ minimum=1,
195
+ maximum=4096,
196
+ value=4096,
197
+ step=1,
198
+ interactive=True,
199
+ label="max tokens",
200
+ )
201
+ presence_penalty_slider = gr.Slider(
202
+ minimum=-2.0,
203
  maximum=2.0,
204
+ value=0.0,
205
+ step=0.01,
206
+ interactive=True,
207
+ label="presence penalty",
208
+ )
209
+ frequency_penalty_slider = gr.Slider(
210
+ minimum=-2.0,
211
+ maximum=2.0,
212
+ value=0.0,
213
+ step=0.01,
214
  interactive=True,
215
+ label="frequency penalty",
216
+ )
217
+ logit_bias_txt = gr.Textbox(
218
+ show_label=True,
219
+ placeholder=f"word:likelihood",
220
+ label="logit bias",
221
+ value="",
222
+ lines=1,
223
+ )
224
+ user = gr.Textbox(
225
+ show_label=True,
226
+ placeholder=f"用于定位滥用行为",
227
+ label="用户名",
228
+ value=user_name.value,
229
+ lines=1,
230
  )
231
 
232
  with gr.Accordion("网络设置", open=False):
modules/base_model.py CHANGED
@@ -41,19 +41,42 @@ class ModelType(Enum):
41
 
42
 
43
  class BaseLLMModel:
44
- def __init__(self, model_name, temperature=1.0, top_p=1.0, max_generation_token=None, system_prompt="") -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  self.history = []
46
  self.all_token_counts = []
47
  self.model_name = model_name
48
  self.model_type = ModelType.get_type(model_name)
49
  self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
50
- self.max_generation_token = max_generation_token if max_generation_token is not None else self.token_upper_limit
51
  self.interrupted = False
52
- self.temperature = temperature
53
- self.top_p = top_p
54
  self.system_prompt = system_prompt
55
  self.api_key = None
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def get_answer_stream_iter(self):
59
  """stream predict, need to be implemented
@@ -75,15 +98,11 @@ class BaseLLMModel:
75
  """get billing infomation, inplement if needed"""
76
  return BILLING_NOT_APPLICABLE_MSG
77
 
78
-
79
  def count_token(self, user_input):
80
- """get token count from input, implement if needed
81
- """
82
  return 0
83
 
84
- def stream_next_chatbot(
85
- self, inputs, chatbot, fake_input=None, display_append=""
86
- ):
87
  def get_return_value():
88
  return chatbot, status_text
89
 
@@ -106,9 +125,7 @@ class BaseLLMModel:
106
  status_text = self.token_message()
107
  yield get_return_value()
108
 
109
- def next_chatbot_at_once(
110
- self, inputs, chatbot, fake_input=None, display_append=""
111
- ):
112
  if fake_input:
113
  chatbot.append((fake_input, ""))
114
  else:
@@ -122,7 +139,7 @@ class BaseLLMModel:
122
  if fake_input is not None:
123
  self.history[-2] = construct_user(fake_input)
124
  self.history[-1] = construct_assistant(ai_reply)
125
- chatbot[-1] = (chatbot[-1][0], ai_reply+display_append)
126
  if fake_input is not None:
127
  self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
128
  else:
@@ -277,12 +294,15 @@ class BaseLLMModel:
277
  self.history = self.history[-4:]
278
  self.all_token_counts = self.all_token_counts[-2:]
279
 
280
-
281
  max_token = self.token_upper_limit - TOKEN_OFFSET
282
 
283
  if sum(self.all_token_counts) > max_token and should_check_token_count:
284
  count = 0
285
- while sum(self.all_token_counts) > self.token_upper_limit * REDUCE_TOKEN_FACTOR and sum(self.all_token_counts) > 0:
 
 
 
 
286
  count += 1
287
  del self.all_token_counts[0]
288
  del self.history[:2]
@@ -385,7 +405,7 @@ class BaseLLMModel:
385
  msg = "删除了一组对话"
386
  return chatbot, msg
387
 
388
- def token_message(self, token_lst = None):
389
  if token_lst is None:
390
  token_lst = self.all_token_counts
391
  token_sum = 0
@@ -433,4 +453,4 @@ class BaseLLMModel:
433
  return filename, json_s["system"], json_s["chatbot"]
434
  except FileNotFoundError:
435
  logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
436
- return filename, self.system_prompt, chatbot
 
41
 
42
 
43
  class BaseLLMModel:
44
+ def __init__(
45
+ self,
46
+ model_name,
47
+ system_prompt="",
48
+ temperature=1.0,
49
+ top_p=1.0,
50
+ n_choices=1,
51
+ stop=None,
52
+ max_generation_token=None,
53
+ presence_penalty=0,
54
+ frequency_penalty=0,
55
+ logit_bias=None,
56
+ user="",
57
+ ) -> None:
58
  self.history = []
59
  self.all_token_counts = []
60
  self.model_name = model_name
61
  self.model_type = ModelType.get_type(model_name)
62
  self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
 
63
  self.interrupted = False
 
 
64
  self.system_prompt = system_prompt
65
  self.api_key = None
66
 
67
+ self.temperature = temperature
68
+ self.top_p = top_p
69
+ self.n_choices = n_choices
70
+ self.stop = stop
71
+ self.max_generation_token = (
72
+ max_generation_token
73
+ if max_generation_token is not None
74
+ else self.token_upper_limit
75
+ )
76
+ self.presence_penalty = presence_penalty
77
+ self.frequency_penalty = frequency_penalty
78
+ self.logit_bias = logit_bias
79
+ self.user = user
80
 
81
  def get_answer_stream_iter(self):
82
  """stream predict, need to be implemented
 
98
  """get billing infomation, inplement if needed"""
99
  return BILLING_NOT_APPLICABLE_MSG
100
 
 
101
  def count_token(self, user_input):
102
+ """get token count from input, implement if needed"""
 
103
  return 0
104
 
105
+ def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
 
 
106
  def get_return_value():
107
  return chatbot, status_text
108
 
 
125
  status_text = self.token_message()
126
  yield get_return_value()
127
 
128
+ def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
 
 
129
  if fake_input:
130
  chatbot.append((fake_input, ""))
131
  else:
 
139
  if fake_input is not None:
140
  self.history[-2] = construct_user(fake_input)
141
  self.history[-1] = construct_assistant(ai_reply)
142
+ chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
143
  if fake_input is not None:
144
  self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
145
  else:
 
294
  self.history = self.history[-4:]
295
  self.all_token_counts = self.all_token_counts[-2:]
296
 
 
297
  max_token = self.token_upper_limit - TOKEN_OFFSET
298
 
299
  if sum(self.all_token_counts) > max_token and should_check_token_count:
300
  count = 0
301
+ while (
302
+ sum(self.all_token_counts)
303
+ > self.token_upper_limit * REDUCE_TOKEN_FACTOR
304
+ and sum(self.all_token_counts) > 0
305
+ ):
306
  count += 1
307
  del self.all_token_counts[0]
308
  del self.history[:2]
 
405
  msg = "删除了一组对话"
406
  return chatbot, msg
407
 
408
+ def token_message(self, token_lst=None):
409
  if token_lst is None:
410
  token_lst = self.all_token_counts
411
  token_sum = 0
 
453
  return filename, json_s["system"], json_s["chatbot"]
454
  except FileNotFoundError:
455
  logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
456
+ return filename, self.system_prompt, chatbot
modules/models.py CHANGED
@@ -26,16 +26,25 @@ from .base_model import BaseLLMModel, ModelType
26
 
27
  class OpenAIClient(BaseLLMModel):
28
  def __init__(
29
- self, model_name, api_key, system_prompt=INITIAL_SYSTEM_PROMPT, temperature=1.0, top_p=1.0
 
 
 
 
 
30
  ) -> None:
31
- super().__init__(model_name=model_name, temperature=temperature, top_p=top_p, system_prompt=system_prompt)
 
 
 
 
 
32
  self.api_key = api_key
33
  self.headers = {
34
  "Content-Type": "application/json",
35
  "Authorization": f"Bearer {self.api_key}",
36
  }
37
 
38
-
39
  def get_answer_stream_iter(self):
40
  response = self._get_response(stream=True)
41
  if response is not None:
@@ -57,7 +66,9 @@ class OpenAIClient(BaseLLMModel):
57
  def count_token(self, user_input):
58
  input_token_count = count_token(construct_user(user_input))
59
  if self.system_prompt is not None and len(self.all_token_counts) == 0:
60
- system_prompt_token_count = count_token(construct_system(self.system_prompt))
 
 
61
  return input_token_count + system_prompt_token_count
62
  return input_token_count
63
 
@@ -70,18 +81,20 @@ class OpenAIClient(BaseLLMModel):
70
  try:
71
  usage_data = self._get_billing_data(usage_url)
72
  except Exception as e:
73
- logging.error(f"获取API使用情况失败:"+str(e))
74
  return f"**获取API使用情况失败**"
75
- rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
76
  return f"**本月使用金额** \u3000 ${rounded_usage}"
77
  except requests.exceptions.ConnectTimeout:
78
- status_text = STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
 
 
79
  return status_text
80
  except requests.exceptions.ReadTimeout:
81
  status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
82
  return status_text
83
  except Exception as e:
84
- logging.error(f"获取API使用情况失败:"+str(e))
85
  return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
86
 
87
  @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
@@ -110,6 +123,7 @@ class OpenAIClient(BaseLLMModel):
110
  "stream": stream,
111
  "presence_penalty": 0,
112
  "frequency_penalty": 0,
 
113
  }
114
  if stream:
115
  timeout = TIMEOUT_STREAMING
@@ -145,7 +159,9 @@ class OpenAIClient(BaseLLMModel):
145
  data = response.json()
146
  return data
147
  else:
148
- raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
 
 
149
 
150
  def _decode_chat_response(self, response):
151
  for chunk in response.iter_lines():
@@ -166,15 +182,25 @@ class OpenAIClient(BaseLLMModel):
166
  # logging.error(f"Error: {e}")
167
  continue
168
 
169
- def get_model(model_name, access_key=None, temperature=None, top_p=None, system_prompt = None) -> BaseLLMModel:
 
 
 
170
  msg = f"模型设置为了: {model_name}"
171
  logging.info(msg)
172
  model_type = ModelType.get_type(model_name)
173
  if model_type == ModelType.OpenAI:
174
- model = OpenAIClient(model_name=model_name, api_key=access_key,system_prompt=system_prompt, temperature=temperature, top_p=top_p)
 
 
 
 
 
 
175
  return model, msg
176
 
177
- if __name__=="__main__":
 
178
  with open("config.json", "r") as f:
179
  openai_api_key = cjson.load(f)["openai_api_key"]
180
  client = OpenAIClient("gpt-3.5-turbo", openai_api_key)
 
26
 
27
  class OpenAIClient(BaseLLMModel):
28
  def __init__(
29
+ self,
30
+ model_name,
31
+ api_key,
32
+ system_prompt=INITIAL_SYSTEM_PROMPT,
33
+ temperature=1.0,
34
+ top_p=1.0,
35
  ) -> None:
36
+ super().__init__(
37
+ model_name=model_name,
38
+ temperature=temperature,
39
+ top_p=top_p,
40
+ system_prompt=system_prompt,
41
+ )
42
  self.api_key = api_key
43
  self.headers = {
44
  "Content-Type": "application/json",
45
  "Authorization": f"Bearer {self.api_key}",
46
  }
47
 
 
48
  def get_answer_stream_iter(self):
49
  response = self._get_response(stream=True)
50
  if response is not None:
 
66
  def count_token(self, user_input):
67
  input_token_count = count_token(construct_user(user_input))
68
  if self.system_prompt is not None and len(self.all_token_counts) == 0:
69
+ system_prompt_token_count = count_token(
70
+ construct_system(self.system_prompt)
71
+ )
72
  return input_token_count + system_prompt_token_count
73
  return input_token_count
74
 
 
81
  try:
82
  usage_data = self._get_billing_data(usage_url)
83
  except Exception as e:
84
+ logging.error(f"获取API使用情况失败:" + str(e))
85
  return f"**获取API使用情况失败**"
86
+ rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
87
  return f"**本月使用金额** \u3000 ${rounded_usage}"
88
  except requests.exceptions.ConnectTimeout:
89
+ status_text = (
90
+ STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
91
+ )
92
  return status_text
93
  except requests.exceptions.ReadTimeout:
94
  status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
95
  return status_text
96
  except Exception as e:
97
+ logging.error(f"获取API使用情况失败:" + str(e))
98
  return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
99
 
100
  @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
 
123
  "stream": stream,
124
  "presence_penalty": 0,
125
  "frequency_penalty": 0,
126
+ "max_tokens": self.max_generation_token,
127
  }
128
  if stream:
129
  timeout = TIMEOUT_STREAMING
 
159
  data = response.json()
160
  return data
161
  else:
162
+ raise Exception(
163
+ f"API request failed with status code {response.status_code}: {response.text}"
164
+ )
165
 
166
  def _decode_chat_response(self, response):
167
  for chunk in response.iter_lines():
 
182
  # logging.error(f"Error: {e}")
183
  continue
184
 
185
+
186
+ def get_model(
187
+ model_name, access_key=None, temperature=None, top_p=None, system_prompt=None
188
+ ) -> BaseLLMModel:
189
  msg = f"模型设置为了: {model_name}"
190
  logging.info(msg)
191
  model_type = ModelType.get_type(model_name)
192
  if model_type == ModelType.OpenAI:
193
+ model = OpenAIClient(
194
+ model_name=model_name,
195
+ api_key=access_key,
196
+ system_prompt=system_prompt,
197
+ temperature=temperature,
198
+ top_p=top_p,
199
+ )
200
  return model, msg
201
 
202
+
203
+ if __name__ == "__main__":
204
  with open("config.json", "r") as f:
205
  openai_api_key = cjson.load(f)["openai_api_key"]
206
  client = OpenAIClient("gpt-3.5-turbo", openai_api_key)