Tuchuanhuhuhu committed
Commit 03f9025 · 1 Parent(s): 1bfb00d

Added support for multi-modal Model: XMBot

ChuanhuChatbot.py CHANGED
@@ -12,6 +12,7 @@ from modules.presets import *
from modules.overwrites import *
from modules.models import get_model

+ gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
gr.Chatbot.postprocess = postprocess
PromptHelper.compact_text_chunks = compact_text_chunks

@@ -321,6 +322,8 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
submitBtn.click(**get_usage_args)

+ index_files.change(handle_file_upload, [current_model, index_files, chatbot], [index_files, chatbot, status_display])
+
emptyBtn.click(
reset,
inputs=[current_model],
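
The index_files.change(...) line added above is the UI half of the new upload flow: whenever the file component's value changes, Gradio calls the handle_file_upload dispatcher (added to modules/utils.py in this commit) with the current model, the uploaded files, and the chat history, and writes the three returned values back to the files component, the chatbot, and the status display. A minimal sketch of that wiring, assuming Gradio 3.x; the component definitions here are illustrative stand-ins for the ones built elsewhere in ChuanhuChatbot.py:

import gradio as gr
from modules.utils import handle_file_upload  # dispatcher added in this commit

with gr.Blocks() as demo:
    current_model = gr.State()            # holds the active BaseLLMModel instance
    chatbot = gr.Chatbot()
    status_display = gr.Markdown()
    index_files = gr.Files(label="上传")   # illustrative; the real component is defined elsewhere in the UI

    # When the uploaded files change, route them to the selected model; the model
    # returns (files update, new chatbot value, status message).
    index_files.change(handle_file_upload,
                       [current_model, index_files, chatbot],
                       [index_files, chatbot, status_display])

This dispatcher pattern keeps the UI generic: each BaseLLMModel subclass decides what an uploaded file means for it.
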
modules/base_model.py CHANGED
@@ -8,6 +8,7 @@ import os
import sys
import requests
import urllib3
+ import traceback

from tqdm import tqdm
import colorama
@@ -28,6 +29,7 @@ class ModelType(Enum):
OpenAI = 0
ChatGLM = 1
LLaMA = 2
+ XMBot = 3

@classmethod
def get_type(cls, model_name: str):
@@ -39,6 +41,8 @@
model_type = ModelType.ChatGLM
elif "llama" in model_name_lower or "alpaca" in model_name_lower:
model_type = ModelType.LLaMA
+ elif "xmbot" in model_name_lower:
+ model_type = ModelType.XMBot
else:
model_type = ModelType.Unknown
return model_type
@@ -164,10 +168,19 @@ class BaseLLMModel:
status_text = self.token_message()
return chatbot, status_text

- def prepare_inputs(self, inputs, use_websearch, files, reply_language):
- old_inputs = None
+ def handle_file_upload(self, files, chatbot):
+ """if the model accepts multi modal input, implement this function"""
+ status = gr.Markdown.update()
+ if files:
+ construct_index(self.api_key, file_src=files)
+ status = "索引构建完成"
+ return gr.Files.update(), chatbot, status
+
+ def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
+ fake_inputs = None
display_append = []
limited_context = False
+ fake_inputs = real_inputs
if files:
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
from llama_index.indices.query.schema import QueryBundle
@@ -180,12 +193,11 @@
OpenAIEmbedding,
)
limited_context = True
- old_inputs = inputs
- msg = "加载索引中……(这可能需要几分钟)"
+ msg = "加载索引中……"
logging.info(msg)
# yield chatbot + [(inputs, "")], msg
index = construct_index(self.api_key, file_src=files)
- assert index is not None, "索引构建失败"
+ assert index is not None, "获取索引失败"
msg = "索引获取成功,生成回答中……"
logging.info(msg)
if local_embedding or self.model_type != ModelType.OpenAI:
@@ -212,22 +224,21 @@
vector_store=index._vector_store,
docstore=index._docstore,
)
- query_bundle = QueryBundle(inputs)
+ query_bundle = QueryBundle(real_inputs)
nodes = query_object.retrieve(query_bundle)
reference_results = [n.node.text for n in nodes]
reference_results = add_source_numbers(reference_results, use_source=False)
display_append = add_details(reference_results)
display_append = "\n\n" + "".join(display_append)
- inputs = (
+ real_inputs = (
replace_today(PROMPT_TEMPLATE)
- .replace("{query_str}", inputs)
+ .replace("{query_str}", real_inputs)
.replace("{context_str}", "\n\n".join(reference_results))
.replace("{reply_language}", reply_language)
)
elif use_websearch:
limited_context = True
- search_results = ddg(inputs, max_results=5)
- old_inputs = inputs
+ search_results = ddg(real_inputs, max_results=5)
reference_results = []
for idx, result in enumerate(search_results):
logging.debug(f"搜索结果{idx + 1}:{result}")
@@ -238,15 +249,15 @@
)
reference_results = add_source_numbers(reference_results)
display_append = "\n\n" + "".join(display_append)
- inputs = (
+ real_inputs = (
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
- .replace("{query}", inputs)
+ .replace("{query}", real_inputs)
.replace("{web_results}", "\n\n".join(reference_results))
.replace("{reply_language}", reply_language)
)
else:
display_append = ""
- return limited_context, old_inputs, display_append, inputs
+ return limited_context, fake_inputs, display_append, real_inputs, chatbot

def predict(
self,
@@ -259,16 +270,17 @@
should_check_token_count=True,
): # repetition_penalty, top_k

-
+ status_text = "开始生成回答……"
logging.info(
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
)
if should_check_token_count:
- yield chatbot + [(inputs, "")], "开始生成回答……"
+ yield chatbot + [(inputs, "")], status_text
if reply_language == "跟随问题语言(不稳定)":
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."

- limited_context, old_inputs, display_append, inputs = self.prepare_inputs(inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language)
+ limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
+ yield chatbot + [(fake_inputs, "")], status_text

if (
self.need_api_key and
@@ -303,7 +315,7 @@
iter = self.stream_next_chatbot(
inputs,
chatbot,
- fake_input=old_inputs,
+ fake_input=fake_inputs,
display_append=display_append,
)
for chatbot, status_text in iter:
@@ -313,11 +325,12 @@
chatbot, status_text = self.next_chatbot_at_once(
inputs,
chatbot,
- fake_input=old_inputs,
+ fake_input=fake_inputs,
display_append=display_append,
)
yield chatbot, status_text
except Exception as e:
+ traceback.print_exc()
status_text = STANDARD_ERROR_MSG + str(e)
yield chatbot, status_text
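
The refactor above splits the user-visible question from the prompt actually sent to the model: prepare_inputs now takes real_inputs plus the chatbot value and returns (limited_context, fake_inputs, display_append, real_inputs, chatbot), where fake_inputs is what gets echoed in the chat window while real_inputs may be rewritten with indexed documents or web-search results. handle_file_upload is the new hook that multi-modal subclasses override; the base version keeps the old behaviour of building a llama_index index from the uploaded files. A rough sketch of such an override (the class name is hypothetical), mirroring what XMBot_Client does in modules/models.py:

class ExampleImageClient(BaseLLMModel):  # hypothetical subclass, for illustration only
    def handle_file_upload(self, files, chatbot):
        # Instead of indexing documents, show each uploaded image in the chat
        # window; Gradio passes file objects whose .name is the temp path on disk.
        if files:
            for file in files:
                chatbot = chatbot + [((file.name,), None)]
        # Return (files update, chatbot, status) to match the base signature.
        return None, chatbot, None
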
modules/models.py CHANGED
@@ -16,6 +16,7 @@ from duckduckgo_search import ddg
import asyncio
import aiohttp
from enum import Enum
+ import uuid

from .presets import *
from .llama_func import *
@@ -75,7 +76,8 @@ class OpenAIClient(BaseLLMModel):
def billing_info(self):
try:
curr_time = datetime.datetime.now()
- last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
+ last_day_of_month = get_last_day_of_month(
+ curr_time).strftime("%Y-%m-%d")
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
try:
@@ -112,7 +114,8 @@ class OpenAIClient(BaseLLMModel):
openai_api_key = self.api_key
system_prompt = self.system_prompt
history = self.history
- logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
+ logging.debug(colorama.Fore.YELLOW +
+ f"{history}" + colorama.Fore.RESET)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {openai_api_key}",
@@ -217,7 +220,7 @@ class ChatGLM_Client(BaseLLMModel):
global CHATGLM_TOKENIZER, CHATGLM_MODEL
if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
system_name = platform.system()
- model_path=None
+ model_path = None
if os.path.exists("models"):
model_dirs = os.listdir("models")
if model_name in model_dirs:
@@ -257,16 +260,19 @@
def _get_glm_style_input(self):
history = [x["content"] for x in self.history]
query = history.pop()
- logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
+ logging.debug(colorama.Fore.YELLOW +
+ f"{history}" + colorama.Fore.RESET)
assert (
len(history) % 2 == 0
), f"History should be even length. current history is: {history}"
- history = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
+ history = [[history[i], history[i + 1]]
+ for i in range(0, len(history), 2)]
return history, query

def get_answer_at_once(self):
history, query = self._get_glm_style_input()
- response, _ = CHATGLM_MODEL.chat(CHATGLM_TOKENIZER, query, history=history)
+ response, _ = CHATGLM_MODEL.chat(
+ CHATGLM_TOKENIZER, query, history=history)
return response, len(response)

def get_answer_stream_iter(self):
@@ -315,8 +321,10 @@ class LLaMA_Client(BaseLLMModel):
# raise Exception(f"models目录下没有这个模型: {model_name}")
if lora_path is not None:
lora_path = f"lora/{lora_path}"
- model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
- pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+ model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
+ use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+ pipeline_args = InferencerArguments(
+ local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')

with open(pipeline_args.deepspeed, "r") as f:
ds_config = json.load(f)
@@ -341,7 +349,6 @@
# " unconditionally."
# )

-
def _get_llama_style_input(self):
history = []
instruction = ""
@@ -379,7 +386,8 @@
step = 1
for _ in range(0, self.max_generation_token, step):
input_dataset = self.dataset.from_dict(
- {"type": "text_only", "instances": [{"text": context + partial_text}]}
+ {"type": "text_only", "instances": [
+ {"text": context + partial_text}]}
)
output_dataset = LLAMA_INFERENCER.inference(
model=LLAMA_MODEL,
@@ -394,6 +402,94 @@
yield partial_text


+ class XMBot_Client(BaseLLMModel):
+ def __init__(self, api_key):
+ super().__init__(model_name="xmbot")
+ self.api_key = api_key
+ self.session_id = None
+ self.reset()
+ self.image_bytes = None
+ self.image_path = None
+ self.xm_history = []
+ self.url = "https://xmbot.net/web"
+
+ def reset(self):
+ self.session_id = str(uuid.uuid4())
+ return [], "已重置"
+
+ def try_read_image(self, filepath):
+ import base64
+
+ def is_image_file(filepath):
+ # 判断文件是否为图片
+ valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
+ file_extension = os.path.splitext(filepath)[1].lower()
+ return file_extension in valid_image_extensions
+
+ def read_image_as_bytes(filepath):
+ # 读取图片文件并返回比特流
+ with open(filepath, "rb") as f:
+ image_bytes = f.read()
+ return image_bytes
+
+ if is_image_file(filepath):
+ logging.info(f"读取图片文件: {filepath}")
+ image_bytes = read_image_as_bytes(filepath)
+ base64_encoded_image = base64.b64encode(image_bytes).decode()
+ self.image_bytes = base64_encoded_image
+ self.image_path = filepath
+ else:
+ self.image_bytes = None
+ self.image_path = None
+
+ def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
+ fake_inputs = real_inputs
+ display_append = ""
+ limited_context = False
+ return limited_context, fake_inputs, display_append, real_inputs, chatbot
+
+ def handle_file_upload(self, files, chatbot):
+ """if the model accepts multi modal input, implement this function"""
+ if files:
+ for file in files:
+ if file.name:
+ logging.info(f"尝试读取图像: {file.name}")
+ self.try_read_image(file.name)
+ if self.image_path is not None:
+ chatbot = chatbot + [((self.image_path,), None)]
+ if self.image_bytes is not None:
+ logging.info("使用图片作为输入")
+ conv_id = str(uuid.uuid4())
+ data = {
+ "user_id": self.api_key,
+ "session_id": self.session_id,
+ "uuid": conv_id,
+ "data_type": "imgbase64",
+ "data": self.image_bytes
+ }
+ # response = requests.post(self.url, json=data)
+ # response = json.loads(response.text)
+ # logging.info(f"图片回复: {response['data']}")
+ logging.info("发送了图片")
+ return None, chatbot, None
+
+ def get_answer_at_once(self):
+ question = self.history[-1]["content"]
+ conv_id = str(uuid.uuid4())
+ data = {
+ "user_id": self.api_key,
+ "session_id": self.session_id,
+ "uuid": conv_id,
+ "data_type": "text",
+ "data": question
+ }
+ response = requests.post(self.url, json=data)
+ response = json.loads(response.text)
+ return response["data"], len(response["data"])
+
+
+
+
def get_model(
model_name,
lora_model_path=None,
@@ -429,7 +525,8 @@ def get_model(
logging.info(msg)
lora_selector_visibility = True
if os.path.isdir("lora"):
- lora_choices = get_file_names("lora", plain=True, filetypes=[""])
+ lora_choices = get_file_names(
+ "lora", plain=True, filetypes=[""])
lora_choices = ["No LoRA"] + lora_choices
elif model_type == ModelType.LLaMA and lora_model_path != "":
logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
@@ -440,6 +537,8 @@
else:
msg += f" + {lora_model_path}"
model = LLaMA_Client(model_name, lora_model_path)
+ elif model_type == ModelType.XMBot:
+ model = XMBot_Client(api_key=access_key)
elif model_type == ModelType.Unknown:
raise ValueError(f"未知模型: {model_name}")
logging.info(msg)
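
XMBot_Client is a thin HTTP client rather than a local model: each message is a JSON POST to https://xmbot.net/web carrying the access key as user_id, a per-conversation session_id (regenerated by reset()), a fresh uuid per message, and a data_type of either "text" or "imgbase64". A sketch of one text round trip as the class performs it; the reply is assumed to carry the answer under the "data" key, which is how get_answer_at_once reads it:

import json
import uuid
import requests

payload = {
    "user_id": "<xmbot access key>",    # set from api_key in __init__
    "session_id": str(uuid.uuid4()),    # one per conversation; reset() makes a new one
    "uuid": str(uuid.uuid4()),          # one per message
    "data_type": "text",                # "imgbase64" when sending an uploaded image
    "data": "你好",                      # the question, or base64-encoded image bytes
}
response = json.loads(requests.post("https://xmbot.net/web", json=payload).text)
print(response["data"])                 # the model's reply text

Image uploads follow the same shape: handle_file_upload base64-encodes the file and builds a payload with data_type "imgbase64", although the actual POST for images is still commented out in this commit.
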
modules/overwrites.py CHANGED
@@ -4,6 +4,7 @@ import logging
from llama_index import Prompt
from typing import List, Tuple
import mdtex2html
+ from gradio_client import utils as client_utils

from modules.presets import *
from modules.llama_func import *
@@ -20,23 +21,60 @@ def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:


def postprocess(
- self, y: List[Tuple[str | None, str | None]]
- ) -> List[Tuple[str | None, str | None]]:
- """
- Parameters:
- y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
- Returns:
- List of tuples representing the message and response. Each message and response will be a string of HTML.
- """
- if y is None or y == []:
- return []
- user, bot = y[-1]
- if not detect_converted_mark(user):
- user = convert_asis(user)
- if not detect_converted_mark(bot):
- bot = convert_mdtext(bot)
- y[-1] = (user, bot)
- return y
+ self,
+ y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple],
+ ) -> List[List[str | Dict | None]]:
+ """
+ Parameters:
+ y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
+ Returns:
+ List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
+ """
+ if y is None:
+ return []
+ processed_messages = []
+ for message_pair in y:
+ assert isinstance(
+ message_pair, (tuple, list)
+ ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
+ assert (
+ len(message_pair) == 2
+ ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
+
+ processed_messages.append(
+ [
+ self._postprocess_chat_messages(message_pair[0], "user"),
+ self._postprocess_chat_messages(message_pair[1], "bot"),
+ ]
+ )
+ return processed_messages
+
+ def postprocess_chat_messages(
+ self, chat_message: str | Tuple | List | None, message_type: str
+ ) -> str | Dict | None:
+ if chat_message is None:
+ return None
+ elif isinstance(chat_message, (tuple, list)):
+ filepath = chat_message[0]
+ mime_type = client_utils.get_mimetype(filepath)
+ filepath = self.make_temp_copy_if_needed(filepath)
+ return {
+ "name": filepath,
+ "mime_type": mime_type,
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
+ "data": None, # These last two fields are filled in by the frontend
+ "is_file": True,
+ }
+ elif isinstance(chat_message, str):
+ if message_type == "bot":
+ if not detect_converted_mark(chat_message):
+ chat_message = convert_mdtext(chat_message)
+ elif message_type == "user":
+ if not detect_converted_mark(chat_message):
+ chat_message = convert_asis(chat_message)
+ return chat_message
+ else:
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")

with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
customJS = f.read()
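
These replacements are monkey-patched onto Gradio's Chatbot in ChuanhuChatbot.py (gr.Chatbot.postprocess and gr.Chatbot._postprocess_chat_messages). The practical effect is that a chat history entry may now be a (filepath,) tuple instead of a string: strings are still run through convert_asis / convert_mdtext as before, while tuples become a file dictionary that the frontend renders as media, which is what lets uploaded images appear in the chat window. A small illustration of the two shapes, with an assumed file path:

import gradio as gr
from modules.overwrites import postprocess, postprocess_chat_messages

# wiring done in ChuanhuChatbot.py in this commit
gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
gr.Chatbot.postprocess = postprocess

chatbot_value = [
    ["你好", "**Hello!**"],              # plain strings: converted to HTML as before
    [("/tmp/uploaded_cat.png",), None],  # media tuple: becomes {"name": ..., "mime_type": ..., "is_file": True}
]
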
modules/presets.py CHANGED
@@ -29,7 +29,7 @@ PROXY_ERROR_MSG = "代理错误,无法获取对话。" # 代理错误
SSL_ERROR_PROMPT = "SSL错误,无法获取对话。" # SSL 错误
NO_APIKEY_MSG = "API key为空,请检查是否输入正确。" # API key 长度不足 51 位
NO_INPUT_MSG = "请输入对话内容。" # 未输入对话内容
- BILLING_NOT_APPLICABLE_MSG = "模型本地运行中" # 本地运行的模型返回的账单信息
+ BILLING_NOT_APPLICABLE_MSG = "账单信息不适用" # 本地运行的模型返回的账单信息

TIMEOUT_STREAMING = 60 # 流式对话时的超时时间
TIMEOUT_ALL = 200 # 非流式对话时的超时时间
@@ -72,6 +72,7 @@ MODELS = [
"gpt-4-0314",
"gpt-4-32k",
"gpt-4-32k-0314",
+ "xmbot",
"chatglm-6b",
"chatglm-6b-int4",
"chatglm-6b-int4-qe",
@@ -85,6 +86,8 @@ MODELS = [
"llama-65b-hf",
] # 可选的模型

+ DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
+
os.makedirs("models", exist_ok=True)
os.makedirs("lora", exist_ok=True)
os.makedirs("history", exist_ok=True)
@@ -93,8 +96,6 @@ for dir_name in os.listdir("models"):
if dir_name not in MODELS:
MODELS.append(dir_name)

- DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
-
MODEL_TOKEN_LIMIT = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
modules/utils.py CHANGED
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
class DataframeData(TypedDict):
headers: List[str]
data: List[List[str | int | bool]]
-
+
def predict(current_model, *args):
iter = current_model.predict(*args)
for i in iter:
@@ -110,6 +110,9 @@ def set_user_identifier(current_model, *args):
def set_single_turn(current_model, *args):
current_model.set_single_turn(*args)

+ def handle_file_upload(current_model, *args):
+ return current_model.handle_file_upload(*args)
+

def count_token(message):
encoding = tiktoken.get_encoding("cl100k_base")
@@ -197,10 +200,13 @@ def convert_asis(userinput):


def detect_converted_mark(userinput):
- if userinput.endswith(ALREADY_CONVERTED_MARK):
+ try:
+ if userinput.endswith(ALREADY_CONVERTED_MARK):
+ return True
+ else:
+ return False
+ except:
return True
- else:
- return False


def detect_language(code):
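
Both utils changes serve the new upload flow: handle_file_upload is a thin dispatcher, in the same style as the other wrappers in this module, so the Gradio callback always reaches whichever model is currently selected; and detect_converted_mark gains a try/except because a chatbot entry can now be a media tuple such as ("image.png",), which has no .endswith method. A self-contained sketch of the new behaviour (the mark value below is a stand-in, not the real constant):

ALREADY_CONVERTED_MARK = "<!-- converted -->"  # stand-in value for illustration

def detect_converted_mark(userinput):
    try:
        if userinput.endswith(ALREADY_CONVERTED_MARK):
            return True
        else:
            return False
    except:
        # media tuples like ("image.png",) have no .endswith: treat them as
        # already converted so postprocessing leaves them untouched
        return True

print(detect_converted_mark("hi" + ALREADY_CONVERTED_MARK))  # True
print(detect_converted_mark("plain text"))                   # False
print(detect_converted_mark(("image.png",)))                 # True instead of AttributeError
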