Tuchuanhuhuhu committed
Commit 0127941 · 1 Parent(s): 9a2b13d

Added the ability to switch models at runtime
Files changed (3):
  1. ChuanhuChatbot.py +5 -5
  2. modules/base_model.py +7 -8
  3. modules/models.py +92 -6
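The core of the change is a thin delegation wrapper: ChuanhuChatbot.py now keeps a single ModelManager in gr.State, and switching models mutates that manager in place rather than replacing the state value, so every event handler already bound to it keeps working. A minimal sketch of the pattern, with a hypothetical FakeClient standing in for the repo's OpenAIClient / ChatGLM_Client:

import logging

class FakeClient:
    # Hypothetical stand-in for OpenAIClient / ChatGLM_Client.
    def __init__(self, model_name):
        self.model_name = model_name

    def predict(self, inputs):
        yield f"[{self.model_name}] echo: {inputs}"

class ModelManager:
    def __init__(self, **kwargs) -> None:
        self.get_model(**kwargs)

    def get_model(self, model_name, **kwargs):
        # Swap the backing model in place; anything holding a reference
        # to the manager (e.g. a gr.State) sees the new model immediately.
        msg = f"Model set to: {model_name}"
        logging.info(msg)
        self.model = FakeClient(model_name)
        return msg

    def predict(self, *args):
        yield from self.model.predict(*args)

manager = ModelManager(model_name="model-a")
manager.get_model("model-b")           # runtime switch, same manager object
print(list(manager.predict("hi")))     # routed to model-b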
ChuanhuChatbot.py CHANGED
@@ -10,7 +10,7 @@ from modules.config import *
 from modules.utils import *
 from modules.presets import *
 from modules.overwrites import *
-from modules.models import get_model
+from modules.models import ModelManager
 
 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks
@@ -22,7 +22,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
     user_question = gr.State("")
-    current_model = gr.State(get_model(MODELS[DEFAULT_MODEL], my_api_key)[0])
+    current_model = gr.State(ModelManager(model_name=MODELS[DEFAULT_MODEL], access_key=my_api_key))
 
     topic = gr.State("未命名对话历史记录")
 
@@ -197,7 +197,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
         interactive=True,
         label="max context",
     )
-    max_tokens_slider = gr.Slider(
+    max_generation_slider = gr.Slider(
         minimum=1,
         maximum=32768,
         value=1000,
@@ -350,7 +350,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     # LLM Models
     keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
-    model_select_dropdown.change(get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display], show_progress=True)
+    model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
 
     # Template
     systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
@@ -392,7 +392,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
     n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
     stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
-    max_tokens_slider.change(current_model.value.set_max_tokens, [max_tokens_slider], None)
+    max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
     presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
     frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
     logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
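Note that the dropdown handler's outputs shrank from [current_model, status_display] to [status_display]: since get_model is a bound method on the object already held in gr.State and swaps self.model by side effect, only the status message needs to flow back to the UI. A minimal Gradio sketch of that wiring, assuming illustrative component and class names:

import gradio as gr

class Manager:
    # Illustrative; the repo's ModelManager plays this role.
    def __init__(self):
        self.name = "model-a"

    def get_model(self, model_name):
        self.name = model_name  # mutate in place; gr.State keeps the same object
        return f"Model set to: {model_name}"

with gr.Blocks() as demo:
    current_model = gr.State(Manager())
    model_select_dropdown = gr.Dropdown(["model-a", "model-b"], value="model-a")
    status_display = gr.Markdown()
    # Only the status message is an output; the state is updated by side effect.
    model_select_dropdown.change(current_model.value.get_model,
                                 [model_select_dropdown], [status_display])

# demo.launch()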
modules/base_model.py CHANGED
@@ -126,13 +126,15 @@ class BaseLLMModel:
 
         stream_iter = self.get_answer_stream_iter()
 
-        self.history.append(construct_assistant(""))
         for partial_text in stream_iter:
-            self.history[-1] = construct_assistant(partial_text)
             chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
             self.all_token_counts[-1] += 1
             status_text = self.token_message()
             yield get_return_value()
+            if self.interrupted:
+                self.recover()
+                break
+        self.history.append(construct_assistant(partial_text))
 
     def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
         if fake_input:
@@ -277,9 +279,6 @@ class BaseLLMModel:
             )
             for chatbot, status_text in iter:
                 yield chatbot, status_text
-                if self.interrupted:
-                    self.recover()
-                    break
         else:
             logging.debug("不使用流式传输")
             chatbot, status_text = self.next_chatbot_at_once(
@@ -326,13 +325,13 @@ class BaseLLMModel:
         files=None,
         reply_language="中文",
     ):
-        logging.info("重试中……")
+        logging.debug("重试中……")
         if len(self.history) == 0:
             yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
             return
 
+        inputs = self.history[-2]["content"]
         del self.history[-2:]
-        inputs = chatbot[-1][0]
         self.all_token_counts.pop()
         iter = self.predict(
             inputs,
@@ -344,7 +343,7 @@ class BaseLLMModel:
         )
         for x in iter:
            yield x
-        logging.info("重试完毕")
+        logging.debug("重试完毕")
 
     # def reduce_token_size(self, chatbot):
     #     logging.info("开始减少token数量……")
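One behavioral fix worth noting in retry(): the prompt to resend is now read from the model's own history before truncating it, rather than from chatbot[-1][0], whose user cell holds display-only text whenever fake_input is used (e.g. file uploads). A small illustration, assuming the message and row shapes visible in the diff:

# History entries are {"role": ..., "content": ...} dicts and chatbot rows are
# (displayed_user_text, displayed_reply) tuples, per the surrounding diff.
history = [
    {"role": "user", "content": "summarize file.txt"},  # the real prompt
    {"role": "assistant", "content": "...reply..."},
]
chatbot = [("[uploaded file.txt]", "...reply...")]       # display-only user cell

# New ordering: capture the real prompt *before* truncating history.
inputs = history[-2]["content"]
del history[-2:]
assert inputs == "summarize file.txt"

# The old code (inputs = chatbot[-1][0]) would have retried the display
# string "[uploaded file.txt]" instead of the actual prompt.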
modules/models.py CHANGED
@@ -247,7 +247,7 @@ class ChatGLM_Client(BaseLLMModel):
     def _get_glm_style_input(self):
         history = [x["content"] for x in self.history]
         query = history.pop()
-        logging.info(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
+        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
         assert (
             len(history) % 2 == 0
         ), f"History should be even length. current history is: {history}"
@@ -365,11 +365,12 @@ class LLaMA_Client(BaseLLMModel):
 
 class ModelManager:
     def __init__(self, **kwargs) -> None:
-        self.model, self.msg = self.get_model(**kwargs)
+        self.get_model(**kwargs)
 
     def get_model(
         self,
         model_name,
+        lora_model_path=None,
         access_key=None,
         temperature=None,
         top_p=None,
@@ -378,7 +379,6 @@ class ModelManager:
         msg = f"模型设置为了: {model_name}"
         logging.info(msg)
         model_type = ModelType.get_type(model_name)
-        print(model_type.name)
         if model_type == ModelType.OpenAI:
             model = OpenAIClient(
                 model_name=model_name,
@@ -389,7 +389,93 @@ class ModelManager:
             )
         elif model_type == ModelType.ChatGLM:
             model = ChatGLM_Client(model_name)
-        return model, msg
+        self.model = model
+        return msg
+
+    def predict(self, *args):
+        iter = self.model.predict(*args)
+        for i in iter:
+            yield i
+
+    def billing_info(self):
+        return self.model.billing_info()
+
+    def set_key(self, *args):
+        return self.model.set_key(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def interrupt(self, *args):
+        return self.model.interrupt(*args)
+
+    def reset(self, *args):
+        return self.model.reset(*args)
+
+    def retry(self, *args):
+        iter = self.model.retry(*args)
+        for i in iter:
+            yield i
+
+    def delete_first_conversation(self, *args):
+        return self.model.delete_first_conversation(*args)
+
+    def delete_last_conversation(self, *args):
+        return self.model.delete_last_conversation(*args)
+
+    def set_system_prompt(self, *args):
+        return self.model.set_system_prompt(*args)
+
+    def save_chat_history(self, *args):
+        return self.model.save_chat_history(*args)
+
+    def export_markdown(self, *args):
+        return self.model.export_markdown(*args)
+
+    def load_chat_history(self, *args):
+        return self.model.load_chat_history(*args)
+
+    def set_token_upper_limit(self, *args):
+        return self.model.set_token_upper_limit(*args)
+
+    # temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
+    # top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
+    # n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
+    # stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
+    # max_tokens_slider.change(current_model.value.set_max_tokens, [max_tokens_slider], None)
+    # presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
+    # frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
+    # logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
+    # user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
+
+    def set_temperature(self, *args):
+        self.model.set_temperature(*args)
+
+    def set_top_p(self, *args):
+        self.model.set_top_p(*args)
+
+    def set_n_choices(self, *args):
+        self.model.set_n_choices(*args)
+
+    def set_stop_sequence(self, *args):
+        self.model.set_stop_sequence(*args)
+
+    def set_max_tokens(self, *args):
+        self.model.set_max_tokens(*args)
+
+    def set_presence_penalty(self, *args):
+        self.model.set_presence_penalty(*args)
+
+    def set_frequency_penalty(self, *args):
+        self.model.set_frequency_penalty(*args)
+
+    def set_logit_bias(self, *args):
+        self.model.set_logit_bias(*args)
+
+    def set_user_identifier(self, *args):
+        self.model.set_user_identifier(*args)
+
 
 
 if __name__ == "__main__":
@@ -397,8 +483,8 @@ if __name__ == "__main__":
     openai_api_key = cjson.load(f)["openai_api_key"]
     # set logging level to debug
     logging.basicConfig(level=logging.DEBUG)
-    # client, _ = get_model("gpt-3.5-turbo", openai_api_key)
-    client, _ = get_model("chatglm-6b-int4")
+    # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
+    client = ModelManager(model_name="chatglm-6b-int4")
     chatbot = []
     stream = False
     # 测试账单功能
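As a design note: the roughly twenty hand-written forwarders in ModelManager could collapse into a single __getattr__ proxy. A hypothetical alternative sketch (SlimModelManager and EchoClient are illustrative, not the commit's code):

class EchoClient:
    # Stand-in for OpenAIClient / ChatGLM_Client; illustrative only.
    def __init__(self, model_name):
        self.model_name = model_name

    def billing_info(self):
        return f"{self.model_name}: no billing data"


class SlimModelManager:
    def __init__(self, **kwargs):
        self.get_model(**kwargs)

    def get_model(self, model_name, **kwargs):
        self.model = EchoClient(model_name)
        return f"Model set to: {model_name}"

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails; guard against
        # infinite recursion if self.model has not been set yet.
        if name == "model":
            raise AttributeError(name)
        return getattr(self.model, name)


mgr = SlimModelManager(model_name="demo")
print(mgr.billing_info())  # forwarded to EchoClient via __getattr__

The trade-off is discoverability: the explicit forwarders in the commit document exactly which parts of the model API the UI relies on, while a __getattr__ proxy forwards everything silently.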