MZhaovo committed
Commit 9911cfb · 2 Parent(s): c445249 90ce14b

Set the local LLM as a global variable to avoid loading it multiple times; Make Class Great Again.

Files changed (4)
  1. ChuanhuChatbot.py +38 -33
  2. modules/models.py +138 -257
  3. modules/presets.py +5 -0
  4. modules/utils.py +76 -0
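In plain terms, the commit does two things: heavyweight local LLM weights (ChatGLM, LLaMA) are cached in module-level globals so they are loaded only once, and the per-session model object moves out of the ModelManager class into a Gradio State. A minimal sketch of the first idea, assuming a ChatGLM-style checkpoint; the helper name `ensure_chatglm` is illustrative, not from the diff:

```python
# Load-once cache for local model weights (sketch, not the repository's exact code).
CHATGLM_MODEL = None  # module-level global shared by every client instance

def ensure_chatglm(model_source: str):
    """Load the weights on first use; later calls reuse the cached global."""
    global CHATGLM_MODEL
    if CHATGLM_MODEL is None:
        from transformers import AutoModel  # heavy import kept local
        CHATGLM_MODEL = AutoModel.from_pretrained(
            model_source, trust_remote_code=True
        ).half().eval()
    return CHATGLM_MODEL
```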
ChuanhuChatbot.py CHANGED
@@ -10,20 +10,22 @@ from modules.config import *
10
  from modules.utils import *
11
  from modules.presets import *
12
  from modules.overwrites import *
13
- from modules.models import ModelManager
14
 
15
  gr.Chatbot.postprocess = postprocess
16
  PromptHelper.compact_text_chunks = compact_text_chunks
17
 
18
- current_model = ModelManager(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)
19
-
20
  with open("assets/custom.css", "r", encoding="utf-8") as f:
21
  customCSS = f.read()
22
 
23
  with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
24
  user_name = gr.State("")
25
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
26
  user_question = gr.State("")
 
27
 
28
  topic = gr.State("未命名对话历史记录")
29
 
@@ -265,8 +267,9 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
265
  gr.Markdown(CHUANHU_DESCRIPTION)
266
  gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
267
  chatgpt_predict_args = dict(
268
- fn=current_model.predict,
269
  inputs=[
 
270
  user_question,
271
  chatbot,
272
  use_streaming_checkbox,
@@ -298,18 +301,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
298
  )
299
 
300
  get_usage_args = dict(
301
- fn=current_model.billing_info, inputs=None, outputs=[usageTxt], show_progress=False
302
  )
303
 
304
  load_history_from_file_args = dict(
305
- fn=current_model.load_chat_history,
306
- inputs=[historyFileSelectDropdown, chatbot, user_name],
307
  outputs=[saveFileName, systemPromptTxt, chatbot]
308
  )
309
 
310
 
311
  # Chatbot
312
- cancelBtn.click(current_model.interrupt, [], [])
313
 
314
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
315
  user_input.submit(**get_usage_args)
@@ -318,15 +321,17 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
318
  submitBtn.click(**get_usage_args)
319
 
320
  emptyBtn.click(
321
- current_model.reset,
 
322
  outputs=[chatbot, status_display],
323
  show_progress=True,
324
  )
325
  emptyBtn.click(**reset_textbox_args)
326
 
327
  retryBtn.click(**start_outputing_args).then(
328
- current_model.retry,
329
  [
 
330
  chatbot,
331
  use_streaming_checkbox,
332
  use_websearch_checkbox,
@@ -339,14 +344,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
339
  retryBtn.click(**get_usage_args)
340
 
341
  delFirstBtn.click(
342
- current_model.delete_first_conversation,
343
- None,
344
  [status_display],
345
  )
346
 
347
  delLastBtn.click(
348
- current_model.delete_last_conversation,
349
- [chatbot],
350
  [chatbot, status_display],
351
  show_progress=False
352
  )
@@ -354,14 +359,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
354
  two_column.change(update_doc_config, [two_column], None)
355
 
356
  # LLM Models
357
- keyTxt.change(current_model.set_key, keyTxt, [status_display]).then(**get_usage_args)
358
  keyTxt.submit(**get_usage_args)
359
- single_turn_checkbox.change(current_model.set_single_turn, single_turn_checkbox, None)
360
- model_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
361
- lora_select_dropdown.change(current_model.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
362
 
363
  # Template
364
- systemPromptTxt.change(current_model.set_system_prompt, [systemPromptTxt], None)
365
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
366
  templateFileSelectDropdown.change(
367
  load_template,
@@ -378,15 +383,15 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
378
 
379
  # S&L
380
  saveHistoryBtn.click(
381
- current_model.save_chat_history,
382
- [saveFileName, chatbot, user_name],
383
  downloadFile,
384
  show_progress=True,
385
  )
386
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
387
  exportMarkdownBtn.click(
388
- current_model.export_markdown,
389
- [saveFileName, chatbot, user_name],
390
  downloadFile,
391
  show_progress=True,
392
  )
@@ -395,16 +400,16 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
395
  downloadFile.change(**load_history_from_file_args)
396
 
397
  # Advanced
398
- max_context_length_slider.change(current_model.set_token_upper_limit, [max_context_length_slider], None)
399
- temperature_slider.change(current_model.set_temperature, [temperature_slider], None)
400
- top_p_slider.change(current_model.set_top_p, [top_p_slider], None)
401
- n_choices_slider.change(current_model.set_n_choices, [n_choices_slider], None)
402
- stop_sequence_txt.change(current_model.set_stop_sequence, [stop_sequence_txt], None)
403
- max_generation_slider.change(current_model.set_max_tokens, [max_generation_slider], None)
404
- presence_penalty_slider.change(current_model.set_presence_penalty, [presence_penalty_slider], None)
405
- frequency_penalty_slider.change(current_model.set_frequency_penalty, [frequency_penalty_slider], None)
406
- logit_bias_txt.change(current_model.set_logit_bias, [logit_bias_txt], None)
407
- user_identifier_txt.change(current_model.set_user_identifier, [user_identifier_txt], None)
408
 
409
  default_btn.click(
410
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
 
10
  from modules.utils import *
11
  from modules.presets import *
12
  from modules.overwrites import *
13
+ from modules.models import get_model
14
 
15
  gr.Chatbot.postprocess = postprocess
16
  PromptHelper.compact_text_chunks = compact_text_chunks
17
 
18
  with open("assets/custom.css", "r", encoding="utf-8") as f:
19
  customCSS = f.read()
20
 
21
+ def create_new_model():
22
+ return get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
23
+
24
  with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
25
  user_name = gr.State("")
26
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
27
  user_question = gr.State("")
28
+ current_model = gr.State(create_new_model)
29
 
30
  topic = gr.State("未命名对话历史记录")
31
 
 
267
  gr.Markdown(CHUANHU_DESCRIPTION)
268
  gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
269
  chatgpt_predict_args = dict(
270
+ fn=predict,
271
  inputs=[
272
+ current_model,
273
  user_question,
274
  chatbot,
275
  use_streaming_checkbox,
 
301
  )
302
 
303
  get_usage_args = dict(
304
+ fn=billing_info, inputs=[current_model], outputs=[usageTxt], show_progress=False
305
  )
306
 
307
  load_history_from_file_args = dict(
308
+ fn=load_chat_history,
309
+ inputs=[current_model, historyFileSelectDropdown, chatbot, user_name],
310
  outputs=[saveFileName, systemPromptTxt, chatbot]
311
  )
312
 
313
 
314
  # Chatbot
315
+ cancelBtn.click(interrupt, [current_model], [])
316
 
317
  user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
318
  user_input.submit(**get_usage_args)
 
321
  submitBtn.click(**get_usage_args)
322
 
323
  emptyBtn.click(
324
+ reset,
325
+ inputs=[current_model],
326
  outputs=[chatbot, status_display],
327
  show_progress=True,
328
  )
329
  emptyBtn.click(**reset_textbox_args)
330
 
331
  retryBtn.click(**start_outputing_args).then(
332
+ retry,
333
  [
334
+ current_model,
335
  chatbot,
336
  use_streaming_checkbox,
337
  use_websearch_checkbox,
 
344
  retryBtn.click(**get_usage_args)
345
 
346
  delFirstBtn.click(
347
+ delete_first_conversation,
348
+ [current_model],
349
  [status_display],
350
  )
351
 
352
  delLastBtn.click(
353
+ delete_last_conversation,
354
+ [current_model, chatbot],
355
  [chatbot, status_display],
356
  show_progress=False
357
  )
 
359
  two_column.change(update_doc_config, [two_column], None)
360
 
361
  # LLM Models
362
+ keyTxt.change(set_key, [current_model, keyTxt], [status_display]).then(**get_usage_args)
363
  keyTxt.submit(**get_usage_args)
364
+ single_turn_checkbox.change(set_single_turn, [current_model, single_turn_checkbox], None)
365
+ model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display, lora_select_dropdown], show_progress=True)
366
+ lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display], show_progress=True)
367
 
368
  # Template
369
+ systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
370
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
371
  templateFileSelectDropdown.change(
372
  load_template,
 
383
 
384
  # S&L
385
  saveHistoryBtn.click(
386
+ save_chat_history,
387
+ [current_model, saveFileName, chatbot, user_name],
388
  downloadFile,
389
  show_progress=True,
390
  )
391
  saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
392
  exportMarkdownBtn.click(
393
+ export_markdown,
394
+ [current_model, saveFileName, chatbot, user_name],
395
  downloadFile,
396
  show_progress=True,
397
  )
 
400
  downloadFile.change(**load_history_from_file_args)
401
 
402
  # Advanced
403
+ max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
404
+ temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
405
+ top_p_slider.change(set_top_p, [current_model, top_p_slider], None)
406
+ n_choices_slider.change(set_n_choices, [current_model, n_choices_slider], None)
407
+ stop_sequence_txt.change(set_stop_sequence, [current_model, stop_sequence_txt], None)
408
+ max_generation_slider.change(set_max_tokens, [current_model, max_generation_slider], None)
409
+ presence_penalty_slider.change(set_presence_penalty, [current_model, presence_penalty_slider], None)
410
+ frequency_penalty_slider.change(set_frequency_penalty, [current_model, frequency_penalty_slider], None)
411
+ logit_bias_txt.change(set_logit_bias, [current_model, logit_bias_txt], None)
412
+ user_identifier_txt.change(set_user_identifier, [current_model, user_identifier_txt], None)
413
 
414
  default_btn.click(
415
  reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
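The other half of the refactor is visible in the wiring above: the model no longer lives in a process-wide ModelManager but in a per-session gr.State, and every callback receives it as its first input. A self-contained sketch of that pattern, with a dummy model standing in for the repo's BaseLLMModel subclasses (names simplified, not the app's real code):

```python
import gradio as gr

class DummyModel:
    """Stand-in for the repo's BaseLLMModel subclasses."""
    def predict(self, text, history):
        return history + [(text, f"echo: {text}")]

def create_new_model():
    return DummyModel()  # the real app builds one via get_model(...)

def predict(current_model, text, history):
    # current_model is the value held by the gr.State below;
    # Gradio passes it to the callback like any other input.
    return current_model.predict(text, history)

with gr.Blocks() as demo:
    current_model = gr.State(create_new_model)  # one model object per session
    chatbot = gr.Chatbot()
    user_input = gr.Textbox()
    user_input.submit(predict, [current_model, user_input, chatbot], [chatbot])

# demo.launch()
```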
modules/models.py CHANGED
@@ -207,51 +207,52 @@ class OpenAIClient(BaseLLMModel):
207
  continue
208
  if error_msg:
209
  raise Exception(error_msg)
210
-
211
 
212
  class ChatGLM_Client(BaseLLMModel):
213
  def __init__(self, model_name) -> None:
214
  super().__init__(model_name=model_name)
215
  from transformers import AutoTokenizer, AutoModel
216
  import torch
217
-
218
- system_name = platform.system()
219
- model_path = None
220
- if os.path.exists("models"):
221
- model_dirs = os.listdir("models")
222
- if model_name in model_dirs:
223
- model_path = f"models/{model_name}"
224
- if model_path is not None:
225
- model_source = model_path
226
- else:
227
- model_source = f"THUDM/{model_name}"
228
- self.tokenizer = AutoTokenizer.from_pretrained(
229
- model_source, trust_remote_code=True
230
- )
231
- quantified = False
232
- if "int4" in model_name:
233
- quantified = True
234
- if quantified:
235
- model = AutoModel.from_pretrained(
236
- model_source, trust_remote_code=True
237
- ).half()
238
- else:
239
- model = AutoModel.from_pretrained(
240
  model_source, trust_remote_code=True
241
- ).half()
242
- if torch.cuda.is_available():
243
- # run on CUDA
244
- logging.info("CUDA is available, using CUDA")
245
- model = model.cuda()
246
- # mps加速还存在一些问题,暂时不使用
247
- elif system_name == "Darwin" and model_path is not None and not quantified:
248
- logging.info("Running on macOS, using MPS")
249
- # running on macOS and model already downloaded
250
- model = model.to("mps")
251
- else:
252
- logging.info("GPU is not available, using CPU")
253
- model = model.eval()
254
- self.model = model
255
 
256
  def _get_glm_style_input(self):
257
  history = [x["content"] for x in self.history]
@@ -265,13 +266,13 @@ class ChatGLM_Client(BaseLLMModel):
265
 
266
  def get_answer_at_once(self):
267
  history, query = self._get_glm_style_input()
268
- response, _ = self.model.chat(self.tokenizer, query, history=history)
269
  return response, len(response)
270
 
271
  def get_answer_stream_iter(self):
272
  history, query = self._get_glm_style_input()
273
- for response, history in self.model.stream_chat(
274
- self.tokenizer,
275
  query,
276
  history,
277
  max_length=self.token_upper_limit,
@@ -292,77 +293,53 @@ class LLaMA_Client(BaseLLMModel):
292
  from lmflow.pipeline.auto_pipeline import AutoPipeline
293
  from lmflow.models.auto_model import AutoModel
294
  from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
295
-
296
- model_path = None
297
- if os.path.exists("models"):
298
- model_dirs = os.listdir("models")
299
- if model_name in model_dirs:
300
- model_path = f"models/{model_name}"
301
- if model_path is not None:
302
- model_source = model_path
303
- else:
304
- model_source = f"decapoda-research/{model_name}"
305
- # raise Exception(f"models目录下没有这个模型: {model_name}")
306
- if lora_path is not None:
307
- lora_path = f"lora/{lora_path}"
308
- self.lora_path = lora_path
309
  self.max_generation_token = 1000
310
- pipeline_name = "inferencer"
311
- model_args = ModelArguments(
312
- model_name_or_path=model_source,
313
- lora_model_path=lora_path,
314
- model_type=None,
315
- config_overrides=None,
316
- config_name=None,
317
- tokenizer_name=None,
318
- cache_dir=None,
319
- use_fast_tokenizer=True,
320
- model_revision="main",
321
- use_auth_token=False,
322
- torch_dtype=None,
323
- use_lora=False,
324
- lora_r=8,
325
- lora_alpha=32,
326
- lora_dropout=0.1,
327
- use_ram_optimized_load=True,
328
- )
329
- pipeline_args = InferencerArguments(
330
- local_rank=0,
331
- random_seed=1,
332
- deepspeed="configs/ds_config_chatbot.json",
333
- mixed_precision="bf16",
334
- )
335
-
336
- with open(pipeline_args.deepspeed, "r") as f:
337
- ds_config = json.load(f)
338
-
339
- self.model = AutoModel.get_model(
340
- model_args,
341
- tune_strategy="none",
342
- ds_config=ds_config,
343
- )
344
-
345
  # We don't need input data
346
  data_args = DatasetArguments(dataset_path=None)
347
  self.dataset = Dataset(data_args)
348
 
349
- self.inferencer = AutoPipeline.get_pipeline(
350
- pipeline_name=pipeline_name,
351
- model_args=model_args,
352
- data_args=data_args,
353
- pipeline_args=pipeline_args,
354
- )
355
-
 
356
  # Chats
357
- model_name = model_args.model_name_or_path
358
- if model_args.lora_model_path is not None:
359
- model_name += f" + {model_args.lora_model_path}"
360
 
361
  # context = (
362
  # "You are a helpful assistant who follows the given instructions"
363
  # " unconditionally."
364
  # )
365
- self.end_string = "\n\n"
366
 
367
  def _get_llama_style_input(self):
368
  history = []
@@ -382,8 +359,8 @@ class LLaMA_Client(BaseLLMModel):
382
  {"type": "text_only", "instances": [{"text": context}]}
383
  )
384
 
385
- output_dataset = self.inferencer.inference(
386
- model=self.model,
387
  dataset=input_dataset,
388
  max_new_tokens=self.max_generation_token,
389
  temperature=self.temperature,
@@ -400,8 +377,8 @@ class LLaMA_Client(BaseLLMModel):
400
  input_dataset = self.dataset.from_dict(
401
  {"type": "text_only", "instances": [{"text": context + partial_text}]}
402
  )
403
- output_dataset = self.inferencer.inference(
404
- model=self.model,
405
  dataset=input_dataset,
406
  max_new_tokens=step,
407
  temperature=self.temperature,
@@ -413,158 +390,62 @@ class LLaMA_Client(BaseLLMModel):
413
  yield partial_text
414
 
415
 
416
- class ModelManager:
417
- def __init__(self, **kwargs) -> None:
418
- self.model = None
419
- self.get_model(**kwargs)
420
-
421
- def get_model(
422
- self,
423
- model_name,
424
- lora_model_path=None,
425
- access_key=None,
426
- temperature=None,
427
- top_p=None,
428
- system_prompt=None,
429
- ) -> BaseLLMModel:
430
- msg = f"模型设置为了: {model_name}"
431
- if self.model is not None and model_name == self.model.model_name:
432
- # 如果模型名字一样,那么就不用重新加载模型
433
- # if (
434
- # lora_model_path is not None
435
- # and hasattr(self.model, "lora_path")
436
- # and lora_model_path == self.model.lora_path
437
- # or lora_model_path is None
438
- # and not hasattr(self.model, "lora_path")
439
- # ):
440
- logging.info(f"模型 {model_name} 已经加载,不需要重新加载")
441
- return msg
442
- model_type = ModelType.get_type(model_name)
443
- lora_selector_visibility = False
444
- lora_choices = []
445
- dont_change_lora_selector = False
446
- if model_type != ModelType.OpenAI:
447
- config.local_embedding = True
448
- del self.model
449
- model = None
450
- try:
451
- if model_type == ModelType.OpenAI:
452
- logging.info(f"正在加载OpenAI模型: {model_name}")
453
- model = OpenAIClient(
454
- model_name=model_name,
455
- api_key=access_key,
456
- system_prompt=system_prompt,
457
- temperature=temperature,
458
- top_p=top_p,
459
- )
460
- elif model_type == ModelType.ChatGLM:
461
- logging.info(f"正在加载ChatGLM模型: {model_name}")
462
- model = ChatGLM_Client(model_name)
463
- elif model_type == ModelType.LLaMA and lora_model_path == "":
464
- msg = f"现在请为 {model_name} 选择LoRA模型"
465
- logging.info(msg)
466
- lora_selector_visibility = True
467
- if os.path.isdir("lora"):
468
- lora_choices = get_file_names("lora", plain=True, filetypes=[""])
469
- lora_choices = ["No LoRA"] + lora_choices
470
- elif model_type == ModelType.LLaMA and lora_model_path != "":
471
- logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
472
- dont_change_lora_selector = True
473
- if lora_model_path == "No LoRA":
474
- lora_model_path = None
475
- msg += " + No LoRA"
476
- else:
477
- msg += f" + {lora_model_path}"
478
- model = LLaMA_Client(model_name, lora_model_path)
479
- elif model_type == ModelType.Unknown:
480
- raise ValueError(f"未知模型: {model_name}")
481
- logging.info(msg)
482
- except Exception as e:
483
- logging.error(e)
484
- msg = f"{STANDARD_ERROR_MSG}: {e}"
485
- self.model = model
486
- if dont_change_lora_selector:
487
- return msg
488
- else:
489
- return msg, gr.Dropdown.update(
490
- choices=lora_choices, visible=lora_selector_visibility
491
  )
492
-
493
- def predict(self, *args):
494
- iter = self.model.predict(*args)
495
- for i in iter:
496
- yield i
497
-
498
- def billing_info(self):
499
- return self.model.billing_info()
500
-
501
- def set_key(self, *args):
502
- return self.model.set_key(*args)
503
-
504
- def load_chat_history(self, *args):
505
- return self.model.load_chat_history(*args)
506
-
507
- def interrupt(self, *args):
508
- return self.model.interrupt(*args)
509
-
510
- def reset(self, *args):
511
- return self.model.reset(*args)
512
-
513
- def retry(self, *args):
514
- iter = self.model.retry(*args)
515
- for i in iter:
516
- yield i
517
-
518
- def delete_first_conversation(self, *args):
519
- return self.model.delete_first_conversation(*args)
520
-
521
- def delete_last_conversation(self, *args):
522
- return self.model.delete_last_conversation(*args)
523
-
524
- def set_system_prompt(self, *args):
525
- return self.model.set_system_prompt(*args)
526
-
527
- def save_chat_history(self, *args):
528
- return self.model.save_chat_history(*args)
529
-
530
- def export_markdown(self, *args):
531
- return self.model.export_markdown(*args)
532
-
533
- def load_chat_history(self, *args):
534
- return self.model.load_chat_history(*args)
535
-
536
- def set_token_upper_limit(self, *args):
537
- return self.model.set_token_upper_limit(*args)
538
-
539
- def set_temperature(self, *args):
540
- self.model.set_temperature(*args)
541
-
542
- def set_top_p(self, *args):
543
- self.model.set_top_p(*args)
544
-
545
- def set_n_choices(self, *args):
546
- self.model.set_n_choices(*args)
547
-
548
- def set_stop_sequence(self, *args):
549
- self.model.set_stop_sequence(*args)
550
-
551
- def set_max_tokens(self, *args):
552
- self.model.set_max_tokens(*args)
553
-
554
- def set_presence_penalty(self, *args):
555
- self.model.set_presence_penalty(*args)
556
-
557
- def set_frequency_penalty(self, *args):
558
- self.model.set_frequency_penalty(*args)
559
-
560
- def set_logit_bias(self, *args):
561
- self.model.set_logit_bias(*args)
562
-
563
- def set_user_identifier(self, *args):
564
- self.model.set_user_identifier(*args)
565
-
566
- def set_single_turn(self, *args):
567
- self.model.set_single_turn(*args)
568
 
569
 
570
  if __name__ == "__main__":
@@ -573,7 +454,7 @@ if __name__ == "__main__":
573
  # set logging level to debug
574
  logging.basicConfig(level=logging.DEBUG)
575
  # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
576
- client = ModelManager(model_name="chatglm-6b-int4")
577
  chatbot = []
578
  stream = False
579
  # 测试账单功能
 
207
  continue
208
  if error_msg:
209
  raise Exception(error_msg)
210
+
211
 
212
  class ChatGLM_Client(BaseLLMModel):
213
  def __init__(self, model_name) -> None:
214
  super().__init__(model_name=model_name)
215
  from transformers import AutoTokenizer, AutoModel
216
  import torch
217
+ global CHATGLM_TOKENIZER, CHATGLM_MODEL
218
+ if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
219
+ system_name = platform.system()
220
+ model_path=None
221
+ if os.path.exists("models"):
222
+ model_dirs = os.listdir("models")
223
+ if model_name in model_dirs:
224
+ model_path = f"models/{model_name}"
225
+ if model_path is not None:
226
+ model_source = model_path
227
+ else:
228
+ model_source = f"THUDM/{model_name}"
229
+ CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
230
  model_source, trust_remote_code=True
231
+ )
232
+ quantified = False
233
+ if "int4" in model_name:
234
+ quantified = True
235
+ if quantified:
236
+ model = AutoModel.from_pretrained(
237
+ model_source, trust_remote_code=True
238
+ ).half()
239
+ else:
240
+ model = AutoModel.from_pretrained(
241
+ model_source, trust_remote_code=True
242
+ ).half()
243
+ if torch.cuda.is_available():
244
+ # run on CUDA
245
+ logging.info("CUDA is available, using CUDA")
246
+ model = model.cuda()
247
+ # mps加速还存在一些问题,暂时不使用
248
+ elif system_name == "Darwin" and model_path is not None and not quantified:
249
+ logging.info("Running on macOS, using MPS")
250
+ # running on macOS and model already downloaded
251
+ model = model.to("mps")
252
+ else:
253
+ logging.info("GPU is not available, using CPU")
254
+ model = model.eval()
255
+ CHATGLM_MODEL = model
256
 
257
  def _get_glm_style_input(self):
258
  history = [x["content"] for x in self.history]
 
266
 
267
  def get_answer_at_once(self):
268
  history, query = self._get_glm_style_input()
269
+ response, _ = CHATGLM_MODEL.chat(CHATGLM_TOKENIZER, query, history=history)
270
  return response, len(response)
271
 
272
  def get_answer_stream_iter(self):
273
  history, query = self._get_glm_style_input()
274
+ for response, history in CHATGLM_MODEL.stream_chat(
275
+ CHATGLM_TOKENIZER,
276
  query,
277
  history,
278
  max_length=self.token_upper_limit,
 
293
  from lmflow.pipeline.auto_pipeline import AutoPipeline
294
  from lmflow.models.auto_model import AutoModel
295
  from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
296
+
297
  self.max_generation_token = 1000
298
+ self.end_string = "\n\n"
299
  # We don't need input data
300
  data_args = DatasetArguments(dataset_path=None)
301
  self.dataset = Dataset(data_args)
302
 
303
+ global LLAMA_MODEL, LLAMA_INFERENCER
304
+ if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
305
+ model_path = None
306
+ if os.path.exists("models"):
307
+ model_dirs = os.listdir("models")
308
+ if model_name in model_dirs:
309
+ model_path = f"models/{model_name}"
310
+ if model_path is not None:
311
+ model_source = model_path
312
+ else:
313
+ model_source = f"decapoda-research/{model_name}"
314
+ # raise Exception(f"models目录下没有这个模型: {model_name}")
315
+ if lora_path is not None:
316
+ lora_path = f"lora/{lora_path}"
317
+ model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
318
+ pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
319
+
320
+ with open(pipeline_args.deepspeed, "r") as f:
321
+ ds_config = json.load(f)
322
+ LLAMA_MODEL = AutoModel.get_model(
323
+ model_args,
324
+ tune_strategy="none",
325
+ ds_config=ds_config,
326
+ )
327
+ LLAMA_INFERENCER = AutoPipeline.get_pipeline(
328
+ pipeline_name="inferencer",
329
+ model_args=model_args,
330
+ data_args=data_args,
331
+ pipeline_args=pipeline_args,
332
+ )
333
  # Chats
334
+ # model_name = model_args.model_name_or_path
335
+ # if model_args.lora_model_path is not None:
336
+ # model_name += f" + {model_args.lora_model_path}"
337
 
338
  # context = (
339
  # "You are a helpful assistant who follows the given instructions"
340
  # " unconditionally."
341
  # )
342
+
343
 
344
  def _get_llama_style_input(self):
345
  history = []
 
359
  {"type": "text_only", "instances": [{"text": context}]}
360
  )
361
 
362
+ output_dataset = LLAMA_INFERENCER.inference(
363
+ model=LLAMA_MODEL,
364
  dataset=input_dataset,
365
  max_new_tokens=self.max_generation_token,
366
  temperature=self.temperature,
 
377
  input_dataset = self.dataset.from_dict(
378
  {"type": "text_only", "instances": [{"text": context + partial_text}]}
379
  )
380
+ output_dataset = LLAMA_INFERENCER.inference(
381
+ model=LLAMA_MODEL,
382
  dataset=input_dataset,
383
  max_new_tokens=step,
384
  temperature=self.temperature,
 
390
  yield partial_text
391
 
392
 
393
+ def get_model(
394
+ model_name,
395
+ lora_model_path=None,
396
+ access_key=None,
397
+ temperature=None,
398
+ top_p=None,
399
+ system_prompt=None,
400
+ ) -> BaseLLMModel:
401
+ msg = f"模型设置为了: {model_name}"
402
+ model_type = ModelType.get_type(model_name)
403
+ lora_selector_visibility = False
404
+ lora_choices = []
405
+ dont_change_lora_selector = False
406
+ if model_type != ModelType.OpenAI:
407
+ config.local_embedding = True
408
+ # del current_model.model
409
+ model = None
410
+ try:
411
+ if model_type == ModelType.OpenAI:
412
+ logging.info(f"正在加载OpenAI模型: {model_name}")
413
+ model = OpenAIClient(
414
+ model_name=model_name,
415
+ api_key=access_key,
416
+ system_prompt=system_prompt,
417
+ temperature=temperature,
418
+ top_p=top_p,
419
  )
420
+ elif model_type == ModelType.ChatGLM:
421
+ logging.info(f"正在加载ChatGLM模型: {model_name}")
422
+ model = ChatGLM_Client(model_name)
423
+ elif model_type == ModelType.LLaMA and lora_model_path == "":
424
+ msg = f"现在请为 {model_name} 选择LoRA模型"
425
+ logging.info(msg)
426
+ lora_selector_visibility = True
427
+ if os.path.isdir("lora"):
428
+ lora_choices = get_file_names("lora", plain=True, filetypes=[""])
429
+ lora_choices = ["No LoRA"] + lora_choices
430
+ elif model_type == ModelType.LLaMA and lora_model_path != "":
431
+ logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
432
+ dont_change_lora_selector = True
433
+ if lora_model_path == "No LoRA":
434
+ lora_model_path = None
435
+ msg += " + No LoRA"
436
+ else:
437
+ msg += f" + {lora_model_path}"
438
+ model = LLaMA_Client(model_name, lora_model_path)
439
+ elif model_type == ModelType.Unknown:
440
+ raise ValueError(f"未知模型: {model_name}")
441
+ logging.info(msg)
442
+ except Exception as e:
443
+ logging.error(e)
444
+ msg = f"{STANDARD_ERROR_MSG}: {e}"
445
+ if dont_change_lora_selector:
446
+ return model, msg
447
+ else:
448
+ return model, msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
449
 
450
 
451
  if __name__ == "__main__":
 
454
  # set logging level to debug
455
  logging.basicConfig(level=logging.DEBUG)
456
  # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
457
+ client = get_model(model_name="chatglm-6b-int4")
458
  chatbot = []
459
  stream = False
460
  # 测试账单功能
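With ModelManager gone, get_model() is now a plain factory: it returns the client object itself plus a status message (and, unless the LoRA selector should stay untouched, a Dropdown update), which Gradio routes into current_model, status_display and lora_select_dropdown. A hypothetical call through the OpenAI branch, with a placeholder key:

```python
from modules.models import get_model

# Return shape taken from the diff above; "sk-..." is a placeholder, not a real key.
model, msg, lora_update = get_model(
    model_name="gpt-3.5-turbo",
    access_key="sk-...",
)
print(msg)  # status text such as "模型设置为了: gpt-3.5-turbo"
```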
modules/presets.py CHANGED
@@ -4,6 +4,11 @@ from pathlib import Path
4
 
5
  import gradio as gr
6
 
7
  # ChatGPT 设置
8
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
9
  API_HOST = "api.openai.com"
 
4
 
5
  import gradio as gr
6
 
7
+ CHATGLM_MODEL = None
8
+ CHATGLM_TOKENIZER = None
9
+ LLAMA_MODEL = None
10
+ LLAMA_INFERENCER = None
11
+
12
  # ChatGPT 设置
13
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
14
  API_HOST = "api.openai.com"
modules/utils.py CHANGED
@@ -33,6 +33,82 @@ if TYPE_CHECKING:
33
  class DataframeData(TypedDict):
34
  headers: List[str]
35
  data: List[List[str | int | bool]]
36
 
37
 
38
  def count_token(message):
 
33
  class DataframeData(TypedDict):
34
  headers: List[str]
35
  data: List[List[str | int | bool]]
36
+
37
+ def predict(current_model, *args):
38
+ iter = current_model.predict(*args)
39
+ for i in iter:
40
+ yield i
41
+
42
+ def billing_info(current_model):
43
+ return current_model.billing_info()
44
+
45
+ def set_key(current_model, *args):
46
+ return current_model.set_key(*args)
47
+
48
+ def load_chat_history(current_model, *args):
49
+ return current_model.load_chat_history(*args)
50
+
51
+ def interrupt(current_model, *args):
52
+ return current_model.interrupt(*args)
53
+
54
+ def reset(current_model, *args):
55
+ return current_model.reset(*args)
56
+
57
+ def retry(current_model, *args):
58
+ iter = current_model.retry(*args)
59
+ for i in iter:
60
+ yield i
61
+
62
+ def delete_first_conversation(current_model, *args):
63
+ return current_model.delete_first_conversation(*args)
64
+
65
+ def delete_last_conversation(current_model, *args):
66
+ return current_model.delete_last_conversation(*args)
67
+
68
+ def set_system_prompt(current_model, *args):
69
+ return current_model.set_system_prompt(*args)
70
+
71
+ def save_chat_history(current_model, *args):
72
+ return current_model.save_chat_history(*args)
73
+
74
+ def export_markdown(current_model, *args):
75
+ return current_model.export_markdown(*args)
76
+
77
+ def load_chat_history(current_model, *args):
78
+ return current_model.load_chat_history(*args)
79
+
80
+ def set_token_upper_limit(current_model, *args):
81
+ return current_model.set_token_upper_limit(*args)
82
+
83
+ def set_temperature(current_model, *args):
84
+ current_model.set_temperature(*args)
85
+
86
+ def set_top_p(current_model, *args):
87
+ current_model.set_top_p(*args)
88
+
89
+ def set_n_choices(current_model, *args):
90
+ current_model.set_n_choices(*args)
91
+
92
+ def set_stop_sequence(current_model, *args):
93
+ current_model.set_stop_sequence(*args)
94
+
95
+ def set_max_tokens(current_model, *args):
96
+ current_model.set_max_tokens(*args)
97
+
98
+ def set_presence_penalty(current_model, *args):
99
+ current_model.set_presence_penalty(*args)
100
+
101
+ def set_frequency_penalty(current_model, *args):
102
+ current_model.set_frequency_penalty(*args)
103
+
104
+ def set_logit_bias(current_model, *args):
105
+ current_model.set_logit_bias(*args)
106
+
107
+ def set_user_identifier(current_model, *args):
108
+ current_model.set_user_identifier(*args)
109
+
110
+ def set_single_turn(current_model, *args):
111
+ current_model.set_single_turn(*args)
112
 
113
 
114
  def count_token(message):
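One detail worth noting in these new dispatchers: predict and retry re-yield from the model's generator instead of returning it, so Gradio still sees a generator function and streams partial output; the one-shot helpers simply delegate. The two shapes, side by side (same pattern as the functions above):

```python
def predict(current_model, *args):
    # Streaming callback: forward each partial update so the UI renders it live.
    for chunk in current_model.predict(*args):
        yield chunk

def billing_info(current_model):
    # One-shot callback: plain delegation is enough.
    return current_model.billing_info()
```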