KevinHuSh committed on
Commit
c1bdfb8
·
1 Parent(s): 2ef1d8e

add Moonshot, debug my_llm (#126)

Browse files
api/apps/conversation_app.py CHANGED
@@ -309,13 +309,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
309
  # compose markdown table
310
  clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
311
  line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
312
- line = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}\|", "|", line)
313
  rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
314
  if not docid_idx or not docnm_idx:
315
  chat_logger.warning("SQL missing field: " + sql)
316
  return "\n".join([clmns, line, "\n".join(rows)]), []
317
 
318
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
 
319
  docid_idx = list(docid_idx)[0]
320
  docnm_idx = list(docnm_idx)[0]
321
  return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
 
309
  # compose markdown table
310
  clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
311
  line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
 
312
  rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
313
  if not docid_idx or not docnm_idx:
314
  chat_logger.warning("SQL missing field: " + sql)
315
  return "\n".join([clmns, line, "\n".join(rows)]), []
316
 
317
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
318
+ rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
319
  docid_idx = list(docid_idx)[0]
320
  docnm_idx = list(docnm_idx)[0]
321
  return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
api/apps/llm_app.py CHANGED
@@ -39,36 +39,40 @@ def factories():
39
  def set_api_key():
40
  req = request.json
41
  # test if api key works
 
 
42
  msg = ""
43
- for llm in LLMService.query(fid=req["llm_factory"]):
44
  if llm.model_type == LLMType.EMBEDDING.value:
45
- mdl = EmbeddingModel[req["llm_factory"]](
46
  req["api_key"], llm.llm_name)
47
  try:
48
  arr, tc = mdl.encode(["Test if the api key is available"])
49
  if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
50
  except Exception as e:
51
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
52
- elif llm.model_type == LLMType.CHAT.value:
53
- mdl = ChatModel[req["llm_factory"]](
54
  req["api_key"], llm.llm_name)
55
  try:
56
  m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
57
  if not tc: raise Exception(m)
 
58
  except Exception as e:
59
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
60
 
61
  if msg: return get_data_error_result(retmsg=msg)
62
 
63
  llm = {
64
- "tenant_id": current_user.id,
65
- "llm_factory": req["llm_factory"],
66
  "api_key": req["api_key"]
67
  }
68
  for n in ["model_type", "llm_name"]:
69
  if n in req: llm[n] = req[n]
70
 
71
- TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
 
 
 
72
  return get_json_result(data=True)
73
 
74
 
 
39
  def set_api_key():
40
  req = request.json
41
  # test if api key works
42
+ chat_passed = False
43
+ factory = req["llm_factory"]
44
  msg = ""
45
+ for llm in LLMService.query(fid=factory):
46
  if llm.model_type == LLMType.EMBEDDING.value:
47
+ mdl = EmbeddingModel[factory](
48
  req["api_key"], llm.llm_name)
49
  try:
50
  arr, tc = mdl.encode(["Test if the api key is available"])
51
  if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
52
  except Exception as e:
53
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
54
+ elif not chat_passed and llm.model_type == LLMType.CHAT.value:
55
+ mdl = ChatModel[factory](
56
  req["api_key"], llm.llm_name)
57
  try:
58
  m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
59
  if not tc: raise Exception(m)
60
+ chat_passed = True
61
  except Exception as e:
62
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
63
 
64
  if msg: return get_data_error_result(retmsg=msg)
65
 
66
  llm = {
 
 
67
  "api_key": req["api_key"]
68
  }
69
  for n in ["model_type", "llm_name"]:
70
  if n in req: llm[n] = req[n]
71
 
72
+ if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
73
+ for llm in LLMService.query(fid=factory):
74
+ TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
75
+
76
  return get_json_result(data=True)
77
 
78
 
api/db/db_models.py CHANGED
@@ -429,7 +429,7 @@ class LLMFactories(DataBaseModel):
429
 
430
  class LLM(DataBaseModel):
431
  # LLMs dictionary
432
- llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
433
  model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
434
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
435
  max_tokens = IntegerField(default=0)
 
429
 
430
  class LLM(DataBaseModel):
431
  # LLMs dictionary
432
+ llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
433
  model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
434
  fid = CharField(max_length=128, null=False, help_text="LLM factory id")
435
  max_tokens = IntegerField(default=0)
api/db/init_data.py CHANGED
@@ -73,41 +73,41 @@ def init_superuser():
73
  print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  def init_llm_factory():
77
- factory_infos = [{
78
- "name": "OpenAI",
79
- "logo": "",
80
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
81
- "status": "1",
82
- },{
83
- "name": "通义千问",
84
- "logo": "",
85
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
86
- "status": "1",
87
- },{
88
- "name": "智谱AI",
89
- "logo": "",
90
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
91
- "status": "1",
92
- },
93
- {
94
- "name": "Local",
95
- "logo": "",
96
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
97
- "status": "1",
98
- },{
99
- "name": "Moonshot",
100
- "logo": "",
101
- "tags": "LLM,TEXT EMBEDDING",
102
- "status": "1",
103
- }
104
- # {
105
- # "name": "文心一言",
106
- # "logo": "",
107
- # "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
108
- # "status": "1",
109
- # },
110
- ]
111
  llm_infos = [
112
  # ---------------------- OpenAI ------------------------
113
  {
@@ -260,21 +260,30 @@ def init_llm_factory():
260
  },
261
  ]
262
  for info in factory_infos:
263
- LLMFactoriesService.save(**info)
 
 
 
264
  for info in llm_infos:
265
- LLMService.save(**info)
 
 
 
266
 
267
 
268
  def init_web_data():
269
  start_time = time.time()
270
 
271
- if not LLMService.get_all().count():init_llm_factory()
 
272
  if not UserService.get_all().count():
273
  init_superuser()
274
 
275
  print("init web data success:{}".format(time.time() - start_time))
276
 
277
 
 
278
  if __name__ == '__main__':
279
  init_web_db()
280
- init_web_data()
 
 
73
  print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
74
 
75
 
76
+ factory_infos = [{
77
+ "name": "OpenAI",
78
+ "logo": "",
79
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
80
+ "status": "1",
81
+ },{
82
+ "name": "通义千问",
83
+ "logo": "",
84
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
85
+ "status": "1",
86
+ },{
87
+ "name": "智谱AI",
88
+ "logo": "",
89
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
90
+ "status": "1",
91
+ },
92
+ {
93
+ "name": "Local",
94
+ "logo": "",
95
+ "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
96
+ "status": "1",
97
+ },{
98
+ "name": "Moonshot",
99
+ "logo": "",
100
+ "tags": "LLM,TEXT EMBEDDING",
101
+ "status": "1",
102
+ }
103
+ # {
104
+ # "name": "文心一言",
105
+ # "logo": "",
106
+ # "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
107
+ # "status": "1",
108
+ # },
109
+ ]
110
  def init_llm_factory():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  llm_infos = [
112
  # ---------------------- OpenAI ------------------------
113
  {
 
260
  },
261
  ]
262
  for info in factory_infos:
263
+ try:
264
+ LLMFactoriesService.save(**info)
265
+ except Exception as e:
266
+ pass
267
  for info in llm_infos:
268
+ try:
269
+ LLMService.save(**info)
270
+ except Exception as e:
271
+ pass
272
 
273
 
274
  def init_web_data():
275
  start_time = time.time()
276
 
277
+ if LLMFactoriesService.get_all().count() != len(factory_infos):
278
+ init_llm_factory()
279
  if not UserService.get_all().count():
280
  init_superuser()
281
 
282
  print("init web data success:{}".format(time.time() - start_time))
283
 
284
 
285
+
286
  if __name__ == '__main__':
287
  init_web_db()
288
+ init_web_data()
289
+ add_tenant_llm()
api/db/services/llm_service.py CHANGED
@@ -53,7 +53,7 @@ class TenantLLMService(CommonService):
53
  cls.model.used_tokens
54
  ]
55
  objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
56
- cls.model.tenant_id == tenant_id).dicts()
57
 
58
  return list(objs)
59
 
 
53
  cls.model.used_tokens
54
  ]
55
  objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
56
+ cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
57
 
58
  return list(objs)
59
 
rag/llm/chat_model.py CHANGED
@@ -54,6 +54,21 @@ class MoonshotChat(GptTurbo):
54
  self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
55
  self.model_name = model_name
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  from dashscope import Generation
59
  class QWenChat(Base):
 
54
  self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
55
  self.model_name = model_name
56
 
57
+ def chat(self, system, history, gen_conf):
58
+ if system: history.insert(0, {"role": "system", "content": system})
59
+ try:
60
+ response = self.client.chat.completions.create(
61
+ model=self.model_name,
62
+ messages=history,
63
+ **gen_conf)
64
+ ans = response.choices[0].message.content.strip()
65
+ if response.choices[0].finish_reason == "length":
66
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
67
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
68
+ return ans, response.usage.completion_tokens
69
+ except openai.APIError as e:
70
+ return "**ERROR**: "+str(e), 0
71
+
72
 
73
  from dashscope import Generation
74
  class QWenChat(Base):