黄腾 aopstudio committed
Commit 6ead7eb · Parent: fd6a873

Refactor some LLM APIs to use the OpenAI API format (#1692)


### What problem does this PR solve?

Refactor some LLM APIs to use the OpenAI API format, replacing hand-rolled HTTP calls with the shared OpenAI-client base classes.

### Type of change

- [x] Refactoring

---------

Co-authored-by: Zhedong Cen <[email protected]>
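
For context, the refactor converges on one pattern: OpenAI-compatible providers stop hand-rolling `requests` calls and instead reuse the OpenAI SDK client set up in a shared base class, so each provider subclass only normalizes its `base_url`. A minimal sketch of that pattern (simplified; the real `Base.chat` in `rag/llm/chat_model.py` also handles errors and truncation):

```python
from openai import OpenAI


class Base:
    def __init__(self, key, model_name, base_url):
        # One OpenAI-format client for every compatible provider;
        # only the endpoint differs.
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = self.client.chat.completions.create(
            model=self.model_name, messages=history, **gen_conf
        )
        return response.choices[0].message.content, response.usage.total_tokens


class OpenRouterChat(Base):
    # After the refactor, a provider is essentially just a default base_url.
    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        super().__init__(key, model_name, base_url)
```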

rag/llm/chat_model.py CHANGED
@@ -24,6 +24,7 @@ from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
+import os
 import json
 import requests
 
@@ -60,9 +61,16 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices:continue
                 ans += resp.choices[0].delta.content
-                total_tokens += 1
+                total_tokens = (
+                    (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                    if not hasattr(resp, "usage")
+                    else resp.usage["total_tokens"]
+                )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -85,8 +93,13 @@ class MoonshotChat(Base):
         if not base_url: base_url="https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
+
 class XinferenceChat(Base):
     def __init__(self, key=None, model_name="", base_url=""):
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
         key = "xxx"
         super().__init__(key, model_name, base_url)
 
@@ -349,79 +362,13 @@ class OllamaChat(Base):
 
 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
-        if base_url[-1] == "/":
-            base_url = base_url[:-1]
-        self.base_url = base_url + "/v1/chat/completions"
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=self.base_url)
         self.model_name = model_name.split("___")[0]
 
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        headers = {
-            "Content-Type": "application/json",
-        }
-        payload = json.dumps(
-            {"model": self.model_name, "messages": history, **gen_conf}
-        )
-        try:
-            response = requests.request(
-                "POST", url=self.base_url, headers=headers, data=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            if response["choices"][0]["finish_reason"] == "length":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        total_tokens = 0
-        try:
-            headers = {
-                "Content-Type": "application/json",
-            }
-            payload = json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": history,
-                    "stream": True,
-                    **gen_conf,
-                }
-            )
-            response = requests.request(
-                "POST",
-                url=self.base_url,
-                headers=headers,
-                data=payload,
-            )
-            for resp in response.content.decode("utf-8").split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "delta" in resp["choices"][0]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                total_tokens += 1
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
-
 
 class LocalLLM(Base):
     class RPCProxy:
@@ -892,9 +839,10 @@ class GroqChat:
 ## openrouter
 class OpenRouterChat(Base):
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
-        self.base_url = "https://openrouter.ai/api/v1"
-        self.client = OpenAI(base_url=self.base_url, api_key=key)
-        self.model_name = model_name
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        super().__init__(key, model_name, base_url)
+
 
 class StepFunChat(Base):
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
@@ -904,87 +852,17 @@ class StepFunChat(Base):
 
 
 class NvidiaChat(Base):
-    def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
-    ):
+    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
         if not base_url:
-            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
-        self.base_url = base_url
-        self.model_name = model_name
-        self.api_key = key
-        self.headers = {
-            "accept": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        payload = {"model": self.model_name, "messages": history, **gen_conf}
-        try:
-            response = requests.post(
-                url=self.base_url, headers=self.headers, json=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        ans = ""
-        total_tokens = 0
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "stream": True,
-            **gen_conf,
-        }
-
-        try:
-            response = requests.post(
-                url=self.base_url,
-                headers=self.headers,
-                json=payload,
-            )
-            for resp in response.text.split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "content" in resp["choices"][0]["delta"]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                if "usage" in resp:
-                    total_tokens = resp["usage"]["total_tokens"]
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
+            base_url = "https://integrate.api.nvidia.com/v1"
+        super().__init__(key, model_name, base_url)
 
 
 class LmStudioChat(Base):
     def __init__(self, key, model_name, base_url):
-        from os.path import join
-
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
-            self.base_url = join(base_url, "v1")
+            self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
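
The `Base.chat_streamly` change above stops counting one token per streamed chunk: it now prefers the provider-reported `usage` and otherwise estimates with `num_tokens_from_string`. The generator contract is unchanged, though: it yields the accumulated answer string after each chunk and finally yields the total token count. A hedged consumption sketch, assuming any `Base`-derived model instance:

```python
def consume_stream(model, system, history, gen_conf):
    # chat_streamly yields str (the answer so far) zero or more times,
    # then a final int (total tokens) -- mirroring the diff above.
    answer, total_tokens = "", 0
    for item in model.chat_streamly(system, history, gen_conf):
        if isinstance(item, str):
            answer = item
        else:
            total_tokens = item
    return answer, total_tokens
```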
rag/llm/cv_model.py CHANGED
@@ -378,7 +378,7 @@ class OllamaCV(Base):
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
-        
+
         try:
             for his in history:
                 if his["role"] == "user":
@@ -433,27 +433,16 @@ class OllamaCV(Base):
         yield 0
 
 
-class LocalAICV(Base):
+class LocalAICV(GptV4):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
+        if not base_url:
+            raise ValueError("Local cv model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        prompt = self.prompt(b64)
-        for i in range(len(prompt)):
-            for c in prompt[i]["content"]:
-                if "text" in c:
-                    c["type"] = "text"
-
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
-        )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
-
 
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
@@ -549,60 +538,19 @@ class GeminiCV(Base):
         yield response._chunks[-1].usage_metadata.total_token_count
 
 
-class OpenRouterCV(Base):
+class OpenRouterCV(GptV4):
     def __init__(
         self,
         key,
         model_name,
         lang="Chinese",
-        base_url="https://openrouter.ai/api/v1/chat/completions",
+        base_url="https://openrouter.ai/api/v1",
     ):
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
-        self.base_url = "https://openrouter.ai/api/v1/chat/completions"
-        self.key = key
-
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        response = requests.post(
-            url=self.base_url,
-            headers={
-                "Authorization": f"Bearer {self.key}",
-            },
-            data=json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": self.prompt(b64),
-                    "max_tokens": max_tokens,
-                }
-            ),
-        )
-        response = response.json()
-        return (
-            response["choices"][0]["message"]["content"].strip(),
-            response["usage"]["total_tokens"],
-        )
-
-    def prompt(self, b64):
-        return [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
-                    },
-                    {
-                        "type": "text",
-                        "text": (
-                            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
-                            if self.lang.lower() == "chinese"
-                            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                        ),
-                    },
-                ],
-            }
-        ]
 
 
 class LocalCV(Base):
@@ -675,12 +623,12 @@ class NvidiaCV(Base):
         ]
 
 
-class LmStudioCV(LocalAICV):
+class LmStudioCV(GptV4):
    def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
-        if base_url.split('/')[-1] != 'v1':
-            self.base_url = os.path.join(base_url,'v1')
-        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
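
`LocalAICV`, `OpenRouterCV`, and `LmStudioCV` now inherit `describe` from `GptV4` rather than duplicating it. For reference, the OpenAI-format vision payload they all end up sending is the one the removed `OpenRouterCV.prompt` built; a standalone sketch (the function name and English-only prompt text are illustrative, not from the diff):

```python
import base64

from openai import OpenAI


def describe_image(client: OpenAI, model_name: str, image_bytes: bytes, max_tokens=300):
    # Images travel inline as base64 data URLs in the OpenAI vision format.
    b64 = base64.b64encode(image_bytes).decode("utf-8")
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            {"type": "text",
             "text": "Please describe the content of this picture, "
                     "like where, when, who, what happen."},
        ],
    }]
    res = client.chat.completions.create(
        model=model_name, messages=messages, max_tokens=max_tokens
    )
    return res.choices[0].message.content.strip(), res.usage.total_tokens
```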
rag/llm/embedding_model.py CHANGED
@@ -113,21 +113,24 @@ class OpenAIEmbed(Base):
 
 class LocalAIEmbed(Base):
     def __init__(self, key, model_name, base_url):
-        self.base_url = base_url + "/embeddings"
-        self.headers = {
-            "Content-Type": "application/json",
-        }
+        if not base_url:
+            raise ValueError("Local embedding model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=None):
-        data = {"model": self.model_name, "input": texts, "encoding_type": "float"}
-        res = requests.post(self.base_url, headers=self.headers, json=data).json()
-
-        return np.array([d["embedding"] for d in res["data"]]), 1024
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts, model=self.model_name)
+        return (
+            np.array([d.embedding for d in res.data]),
+            1024,
+        ) # local embedding for LmStudio donot count tokens
 
     def encode_queries(self, text):
-        embds, cnt = self.encode([text])
-        return np.array(embds[0]), cnt
+        res = self.client.embeddings.create(text, model=self.model_name)
+        return np.array(res.data[0].embedding), 1024
+
 
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
@@ -502,7 +505,7 @@ class NvidiaEmbed(Base):
         return np.array(embds[0]), cnt
 
 
-class LmStudioEmbed(Base):
+class LmStudioEmbed(LocalAIEmbed):
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -510,14 +513,3 @@ class LmStudioEmbed(Base):
         self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
-
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts, model=self.model_name)
-        return (
-            np.array([d.embedding for d in res.data]),
-            1024,
-        ) # local embedding for LmStudio donot count tokens
-
-    def encode_queries(self, text):
-        res = self.client.embeddings.create(text, model=self.model_name)
-        return np.array(res.data[0].embedding), 1024
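
A short usage sketch of the refactored `LocalAIEmbed` (the server URL and model name below are hypothetical). Note the second return value: local servers do not report token usage, so per the comment in the diff the class returns a hard-coded 1024 in its place:

```python
# Hypothetical local server and model name; "___" suffixes are stripped in __init__.
embedder = LocalAIEmbed(key="empty", model_name="bge-m3___LocalAI",
                        base_url="http://localhost:8080")
vectors, token_count = embedder.encode(["hello world", "how are you"])
print(vectors.shape)  # (2, <model embedding dim>)
print(token_count)    # 1024 -- placeholder, not a real token count
```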