guest committed on
Commit
066a302
·
1 Parent(s): 0a2a419

多用户使用同一个后端model

Browse files
Files changed (2) hide show
  1. ChuanhuChatbot.py +27 -26
  2. modules/models.py +43 -8
ChuanhuChatbot.py CHANGED
@@ -15,6 +15,8 @@ from modules.models import ModelManager
15
  gr.Chatbot.postprocess = postprocess
16
  PromptHelper.compact_text_chunks = compact_text_chunks
17
 
 
 
18
  with open("assets/custom.css", "r", encoding="utf-8") as f:
19
  customCSS = f.read()
20
 
@@ -22,7 +24,6 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
22
  user_name = gr.State("")
23
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
24
  user_question = gr.State("")
25
- current_model = gr.State(ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key))
26
 
27
  topic = gr.State("未命名对话历史记录")
28
 
@@ -264,7 +265,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
264
  gr.Markdown(CHUANHU_DESCRIPTION)
265
  gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
266
  chatgpt_predict_args = dict(
267
- fn=current_model.value.predict,
268
  inputs=[
269
  user_question,
270
  chatbot,
@@ -297,18 +298,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
297
  )
298
 
299
  get_usage_args = dict(
300
- fn=current_model.value.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
301
  )
302
 
303
  load_history_from_file_args = dict(
304
- fn=current_model.value.load_chat_history,
305
  inputs=[historyFileSelectDropdown, chatbot, user_name],
306
  outputs=[saveFileName, systemPromptTxt, chatbot]
307
  )
308
 
309
 
310
  # Chatbot
311
- cancelBtn.click(current_model.value.interrupt, [], [])
312
 
313
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
314
  user_input.submit(**get_usage_args)
@@ -317,14 +318,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
317
  submitBtn.click(**get_usage_args)
318
 
319
  emptyBtn.click(
320
- current_model.value.reset,
321
  outputs=[chatbot, status_display],
322
  show_progress=True,
323
  )
324
  emptyBtn.click(**reset_textbox_args)
325
 
326
  retryBtn.click(**start_outputing_args).then(
327
- current_model.value.retry,
328
  [
329
  chatbot,
330
  use_streaming_checkbox,
@@ -338,13 +339,13 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
338
  retryBtn.click(**get_usage_args)
339
 
340
  delFirstBtn.click(
341
- current_model.value.delete_first_conversation,
342
  None,
343
  [status_display],
344
  )
345
 
346
  delLastBtn.click(
347
- current_model.value.delete_last_conversation,
348
  [chatbot],
349
  [chatbot, status_display],
350
  show_progress=False
@@ -353,14 +354,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
353
  two_column.change(update_doc_config, [two_column], None)
354
 
355
  # LLM Models
356
- keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
357
  keyTxt.submit(**get_usage_args)
358
- single_turn_checkbox.change(current_model.value.set_single_turn, single_turn_checkbox, None)
359
- model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
360
- lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
361
 
362
  # Template
363
- systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
364
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
365
  templateFileSelectDropdown.change(
366
  load_template,
@@ -377,14 +378,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
377
 
378
  # S&L
379
  saveHistoryBtn.click(
380
- current_model.value.save_chat_history,
381
  [saveFileName, chatbot, user_name],
382
  downloadFile,
383
  show_progress=True,
384
  )
385
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
386
  exportMarkdownBtn.click(
387
- current_model.value.export_markdown,
388
  [saveFileName, chatbot, user_name],
389
  downloadFile,
390
  show_progress=True,
@@ -394,16 +395,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
394
  downloadFile.change(**load_history_from_file_args)
395
 
396
  # Advanced
397
- max_context_length_slider.change(current_model.value.set_token_upper_limit, [max_context_length_slider], None)
398
- temperature_slider.change(current_model.value.set_temperature, [temperature_slider], None)
399
- top_p_slider.change(current_model.value.set_top_p, [top_p_slider], None)
400
- n_choices_slider.change(current_model.value.set_n_choices, [n_choices_slider], None)
401
- stop_sequence_txt.change(current_model.value.set_stop_sequence, [stop_sequence_txt], None)
402
- max_generation_slider.change(current_model.value.set_max_tokens, [max_generation_slider], None)
403
- presence_penalty_slider.change(current_model.value.set_presence_penalty, [presence_penalty_slider], None)
404
- frequency_penalty_slider.change(current_model.value.set_frequency_penalty, [frequency_penalty_slider], None)
405
- logit_bias_txt.change(current_model.value.set_logit_bias, [logit_bias_txt], None)
406
- user_identifier_txt.change(current_model.value.set_user_identifier, [user_identifier_txt], None)
407
 
408
  default_btn.click(
409
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
 
15
  gr.Chatbot.postprocess = postprocess
16
  PromptHelper.compact_text_chunks = compact_text_chunks
17
 
18
+ current_model = ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)
19
+
20
  with open("assets/custom.css", "r", encoding="utf-8") as f:
21
  customCSS = f.read()
22
 
 
24
  user_name = gr.State("")
25
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
26
  user_question = gr.State("")
 
27
 
28
  topic = gr.State("未命名对话历史记录")
29
 
 
265
  gr.Markdown(CHUANHU_DESCRIPTION)
266
  gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
267
  chatgpt_predict_args = dict(
268
+ fn=current_model.predict,
269
  inputs=[
270
  user_question,
271
  chatbot,
 
298
  )
299
 
300
  get_usage_args = dict(
301
+ fn=current_model.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
302
  )
303
 
304
  load_history_from_file_args = dict(
305
+ fn=current_model.load_chat_history,
306
  inputs=[historyFileSelectDropdown, chatbot, user_name],
307
  outputs=[saveFileName, systemPromptTxt, chatbot]
308
  )
309
 
310
 
311
  # Chatbot
312
+ cancelBtn.click(current_model.interrupt, [], [])
313
 
314
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
315
  user_input.submit(**get_usage_args)
 
318
  submitBtn.click(**get_usage_args)
319
 
320
  emptyBtn.click(
321
+ current_model.reset,
322
  outputs=[chatbot, status_display],
323
  show_progress=True,
324
  )
325
  emptyBtn.click(**reset_textbox_args)
326
 
327
  retryBtn.click(**start_outputing_args).then(
328
+ current_model.retry,
329
  [
330
  chatbot,
331
  use_streaming_checkbox,
 
339
  retryBtn.click(**get_usage_args)
340
 
341
  delFirstBtn.click(
342
+ current_model.delete_first_conversation,
343
  None,
344
  [status_display],
345
  )
346
 
347
  delLastBtn.click(
348
+ current_model.delete_last_conversation,
349
  [chatbot],
350
  [chatbot, status_display],
351
  show_progress=False
 
354
  two_column.change(update_doc_config, [two_column], None)
355
 
356
  # LLM Models
357
+ keyTxt.change(current_model.set_key, keyTxt, [status_display]).then(**get_usage_args)
358
  keyTxt.submit(**get_usage_args)
359
+ single_turn_checkbox.change(current_model.set_single_turn, single_turn_checkbox, None)
360
+ model_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
361
+ lora_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
362
 
363
  # Template
364
+ systemPromptTxt.change(current_model.set_system_prompt, [systemPromptTxt], None)
365
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
366
  templateFileSelectDropdown.change(
367
  load_template,
 
378
 
379
  # S&L
380
  saveHistoryBtn.click(
381
+ current_model.save_chat_history,
382
  [saveFileName, chatbot, user_name],
383
  downloadFile,
384
  show_progress=True,
385
  )
386
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
387
  exportMarkdownBtn.click(
388
+ current_model.export_markdown,
389
  [saveFileName, chatbot, user_name],
390
  downloadFile,
391
  show_progress=True,
 
395
  downloadFile.change(**load_history_from_file_args)
396
 
397
  # Advanced
398
+ max_context_length_slider.change(current_model.set_token_upper_limit, [max_context_length_slider], None)
399
+ temperature_slider.change(current_model.set_temperature, [temperature_slider], None)
400
+ top_p_slider.change(current_model.set_top_p, [top_p_slider], None)
401
+ n_choices_slider.change(current_model.set_n_choices, [n_choices_slider], None)
402
+ stop_sequence_txt.change(current_model.set_stop_sequence, [stop_sequence_txt], None)
403
+ max_generation_slider.change(current_model.set_max_tokens, [max_generation_slider], None)
404
+ presence_penalty_slider.change(current_model.set_presence_penalty, [presence_penalty_slider], None)
405
+ frequency_penalty_slider.change(current_model.set_frequency_penalty, [frequency_penalty_slider], None)
406
+ logit_bias_txt.change(current_model.set_logit_bias, [logit_bias_txt], None)
407
+ user_identifier_txt.change(current_model.set_user_identifier, [user_identifier_txt], None)
408
 
409
  default_btn.click(
410
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
modules/models.py CHANGED
@@ -195,7 +195,7 @@ class OpenAIClient(BaseLLMModel):
195
  chunk = json.loads(chunk[6:])
196
  except json.JSONDecodeError:
197
  print(f"JSON解析错误,收到的内容: {chunk}")
198
- error_msg+=chunk
199
  continue
200
  if chunk_length > 6 and "delta" in chunk["choices"][0]:
201
  if chunk["choices"][0]["finish_reason"] == "stop":
@@ -216,7 +216,7 @@ class ChatGLM_Client(BaseLLMModel):
216
  import torch
217
 
218
  system_name = platform.system()
219
- model_path=None
220
  if os.path.exists("models"):
221
  model_dirs = os.listdir("models")
222
  if model_name in model_dirs:
@@ -292,6 +292,7 @@ class LLaMA_Client(BaseLLMModel):
292
  from lmflow.pipeline.auto_pipeline import AutoPipeline
293
  from lmflow.models.auto_model import AutoModel
294
  from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
 
295
  model_path = None
296
  if os.path.exists("models"):
297
  model_dirs = os.listdir("models")
@@ -304,10 +305,33 @@ class LLaMA_Client(BaseLLMModel):
304
  # raise Exception(f"models目录下没有这个模型: {model_name}")
305
  if lora_path is not None:
306
  lora_path = f"lora/{lora_path}"
 
307
  self.max_generation_token = 1000
308
  pipeline_name = "inferencer"
309
- model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
310
- pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  with open(pipeline_args.deepspeed, "r") as f:
313
  ds_config = json.load(f)
@@ -374,7 +398,7 @@ class LLaMA_Client(BaseLLMModel):
374
  step = 1
375
  for _ in range(0, self.max_generation_token, step):
376
  input_dataset = self.dataset.from_dict(
377
- {"type": "text_only", "instances": [{"text": context+partial_text}]}
378
  )
379
  output_dataset = self.inferencer.inference(
380
  model=self.model,
@@ -404,6 +428,17 @@ class ModelManager:
404
  system_prompt=None,
405
  ) -> BaseLLMModel:
406
  msg = f"模型设置为了: {model_name}"
 
 
 
 
 
 
 
 
 
 
 
407
  model_type = ModelType.get_type(model_name)
408
  lora_selector_visibility = False
409
  lora_choices = []
@@ -451,7 +486,9 @@ class ModelManager:
451
  if dont_change_lora_selector:
452
  return msg
453
  else:
454
- return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
 
 
455
 
456
  def predict(self, *args):
457
  iter = self.model.predict(*args)
@@ -530,8 +567,6 @@ class ModelManager:
530
  self.model.set_single_turn(*args)
531
 
532
 
533
-
534
-
535
  if __name__ == "__main__":
536
  with open("config.json", "r") as f:
537
  openai_api_key = cjson.load(f)["openai_api_key"]
 
195
  chunk = json.loads(chunk[6:])
196
  except json.JSONDecodeError:
197
  print(f"JSON解析错误,收到的内容: {chunk}")
198
+ error_msg += chunk
199
  continue
200
  if chunk_length > 6 and "delta" in chunk["choices"][0]:
201
  if chunk["choices"][0]["finish_reason"] == "stop":
 
216
  import torch
217
 
218
  system_name = platform.system()
219
+ model_path = None
220
  if os.path.exists("models"):
221
  model_dirs = os.listdir("models")
222
  if model_name in model_dirs:
 
292
  from lmflow.pipeline.auto_pipeline import AutoPipeline
293
  from lmflow.models.auto_model import AutoModel
294
  from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
295
+
296
  model_path = None
297
  if os.path.exists("models"):
298
  model_dirs = os.listdir("models")
 
305
  # raise Exception(f"models目录下没有这个模型: {model_name}")
306
  if lora_path is not None:
307
  lora_path = f"lora/{lora_path}"
308
+ self.lora_path = lora_path
309
  self.max_generation_token = 1000
310
  pipeline_name = "inferencer"
311
+ model_args = ModelArguments(
312
+ model_name_or_path=model_source,
313
+ lora_model_path=lora_path,
314
+ model_type=None,
315
+ config_overrides=None,
316
+ config_name=None,
317
+ tokenizer_name=None,
318
+ cache_dir=None,
319
+ use_fast_tokenizer=True,
320
+ model_revision="main",
321
+ use_auth_token=False,
322
+ torch_dtype=None,
323
+ use_lora=False,
324
+ lora_r=8,
325
+ lora_alpha=32,
326
+ lora_dropout=0.1,
327
+ use_ram_optimized_load=True,
328
+ )
329
+ pipeline_args = InferencerArguments(
330
+ local_rank=0,
331
+ random_seed=1,
332
+ deepspeed="configs/ds_config_chatbot.json",
333
+ mixed_precision="bf16",
334
+ )
335
 
336
  with open(pipeline_args.deepspeed, "r") as f:
337
  ds_config = json.load(f)
 
398
  step = 1
399
  for _ in range(0, self.max_generation_token, step):
400
  input_dataset = self.dataset.from_dict(
401
+ {"type": "text_only", "instances": [{"text": context + partial_text}]}
402
  )
403
  output_dataset = self.inferencer.inference(
404
  model=self.model,
 
428
  system_prompt=None,
429
  ) -> BaseLLMModel:
430
  msg = f"模型设置为了: {model_name}"
431
+ if self.model is not None and model_name == self.model.model_name:
432
+ # 如果模型名字一样,那么就不用重新加载模型
433
+ # if (
434
+ # lora_model_path is not None
435
+ # and hasattr(self.model, "lora_path")
436
+ # and lora_model_path == self.model.lora_path
437
+ # or lora_model_path is None
438
+ # and not hasattr(self.model, "lora_path")
439
+ # ):
440
+ logging.info(f"模型 {model_name} 已经加载,不需要重新加载")
441
+ return msg
442
  model_type = ModelType.get_type(model_name)
443
  lora_selector_visibility = False
444
  lora_choices = []
 
486
  if dont_change_lora_selector:
487
  return msg
488
  else:
489
+ return msg, gr.Dropdown.update(
490
+ choices=lora_choices, visible=lora_selector_visibility
491
+ )
492
 
493
  def predict(self, *args):
494
  iter = self.model.predict(*args)
 
567
  self.model.set_single_turn(*args)
568
 
569
 
 
 
570
  if __name__ == "__main__":
571
  with open("config.json", "r") as f:
572
  openai_api_key = cjson.load(f)["openai_api_key"]