refactor some llm api using openai api format (#1692)
### What problem does this PR solve?
Refactor several LLM provider wrappers (chat, CV, and embedding models) to use the OpenAI-compatible API format, replacing hand-rolled `requests` calls with the shared OpenAI client base classes.
### Type of change
- [x] Refactoring
---------
Co-authored-by: Zhedong Cen <[email protected]>
- rag/llm/chat_model.py +28 -150
- rag/llm/cv_model.py +15 -67
- rag/llm/embedding_model.py +15 -23
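
The pattern running through all three diffs below: wrappers that previously hand-rolled `requests` calls against each vendor's HTTP endpoint now normalize `base_url` and delegate to the shared OpenAI-compatible base classes, which is why every file loses far more lines than it gains. A minimal sketch of that pattern, assuming the `openai` v1 SDK; the class names here are illustrative, not the exact repo code:

```python
from openai import OpenAI


class OpenAICompatChat:
    """Sketch: one OpenAI-format client can serve many providers."""

    def __init__(self, key, model_name, base_url):
        # OpenAI-compatible vendors differ only in endpoint and key.
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            resp = self.client.chat.completions.create(
                model=self.model_name, messages=history, **gen_conf)
            return resp.choices[0].message.content.strip(), resp.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0


class MoonshotLikeChat(OpenAICompatChat):
    # A vendor wrapper shrinks to a constructor that pins the endpoint.
    def __init__(self, key, model_name, base_url="https://api.moonshot.cn/v1"):
        super().__init__(key, model_name, base_url or "https://api.moonshot.cn/v1")
```

Once chat, streaming, and error handling live in the base class, each provider-specific subclass is a few lines of endpoint plumbing.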
#### rag/llm/chat_model.py (CHANGED)
```diff
@@ -24,6 +24,7 @@ from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
+import os
 import json
 import requests
 
@@ -60,9 +61,16 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices…
+                if not resp.choices:continue
                 ans += resp.choices[0].delta.content
-                total_tokens…
+                total_tokens = (
+                    (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                    if not hasattr(resp, "usage")
+                    else resp.usage["total_tokens"]
+                )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -85,8 +93,13 @@ class MoonshotChat(Base):
         if not base_url: base_url="https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
+
 class XinferenceChat(Base):
     def __init__(self, key=None, model_name="", base_url=""):
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
         key = "xxx"
         super().__init__(key, model_name, base_url)
 
@@ -349,79 +362,13 @@ class OllamaChat(Base):
 
 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
-        if base_url…
-        …
-        …
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=self.base_url)
         self.model_name = model_name.split("___")[0]
 
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        headers = {
-            "Content-Type": "application/json",
-        }
-        payload = json.dumps(
-            {"model": self.model_name, "messages": history, **gen_conf}
-        )
-        try:
-            response = requests.request(
-                "POST", url=self.base_url, headers=headers, data=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            if response["choices"][0]["finish_reason"] == "length":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        total_tokens = 0
-        try:
-            headers = {
-                "Content-Type": "application/json",
-            }
-            payload = json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": history,
-                    "stream": True,
-                    **gen_conf,
-                }
-            )
-            response = requests.request(
-                "POST",
-                url=self.base_url,
-                headers=headers,
-                data=payload,
-            )
-            for resp in response.content.decode("utf-8").split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "delta" in resp["choices"][0]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                total_tokens += 1
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
-
 
 class LocalLLM(Base):
     class RPCProxy:
@@ -892,9 +839,10 @@ class GroqChat:
 ## openrouter
 class OpenRouterChat(Base):
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
-        …
-        …
-        …
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        super().__init__(key, model_name, base_url)
+
 
 class StepFunChat(Base):
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
@@ -904,87 +852,17 @@ class StepFunChat(Base):
 
 
 class NvidiaChat(Base):
-    def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
-    ):
+    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
         if not base_url:
-            base_url = "https://integrate.api.nvidia.com/v1…
-
-        self.model_name = model_name
-        self.api_key = key
-        self.headers = {
-            "accept": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        payload = {"model": self.model_name, "messages": history, **gen_conf}
-        try:
-            response = requests.post(
-                url=self.base_url, headers=self.headers, json=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        ans = ""
-        total_tokens = 0
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "stream": True,
-            **gen_conf,
-        }
-
-        try:
-            response = requests.post(
-                url=self.base_url,
-                headers=self.headers,
-                json=payload,
-            )
-            for resp in response.text.split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "content" in resp["choices"][0]["delta"]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                if "usage" in resp:
-                    total_tokens = resp["usage"]["total_tokens"]
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
+            base_url = "https://integrate.api.nvidia.com/v1"
+        super().__init__(key, model_name, base_url)
 
 
 class LmStudioChat(Base):
     def __init__(self, key, model_name, base_url):
-        from os.path import join
-
         if not base_url:
             raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
-            self.base_url = join(base_url, "v1")
+            self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
```
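One behavioral detail in the `Base.chat_streamly` hunk above: the new code prefers the server-reported `usage` total and falls back to estimating tokens from each streamed delta with `num_tokens_from_string`. A standalone sketch of that accounting, with hypothetical chunk dicts standing in for SDK response objects:

```python
def count_stream_tokens(chunks, estimate):
    """Prefer a server-reported usage total; otherwise estimate from each
    streamed delta (mirrors the Base.chat_streamly change above)."""
    total = 0
    for chunk in chunks:
        if chunk.get("usage"):                       # authoritative server count
            total = chunk["usage"]["total_tokens"]
        else:                                        # estimate from the delta text
            total += estimate(chunk.get("delta", ""))
    return total


# Hypothetical stream: two plain deltas, then a final chunk carrying usage.
chunks = [{"delta": "Hello"}, {"delta": " world"}, {"usage": {"total_tokens": 7}}]
print(count_stream_tokens(chunks, lambda s: max(1, len(s) // 4)))  # -> 7
```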
#### rag/llm/cv_model.py (CHANGED)
```diff
@@ -378,7 +378,7 @@ class OllamaCV(Base):
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
-
+
         try:
             for his in history:
                 if his["role"] == "user":
@@ -433,27 +433,16 @@ class OllamaCV(Base):
                 yield 0
 
 
-class LocalAICV(Base):
+class LocalAICV(GptV4):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
+        if not base_url:
+            raise ValueError("Local cv model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        prompt = self.prompt(b64)
-        for i in range(len(prompt)):
-            for c in prompt[i]["content"]:
-                if "text" in c:
-                    c["type"] = "text"
-
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
-        )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
-
 
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
@@ -549,60 +538,19 @@ class GeminiCV(Base):
             yield response._chunks[-1].usage_metadata.total_token_count
 
 
-class OpenRouterCV(Base):
+class OpenRouterCV(GptV4):
     def __init__(
         self,
         key,
         model_name,
         lang="Chinese",
-        base_url="https://openrouter.ai/api/v1…
+        base_url="https://openrouter.ai/api/v1",
     ):
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
-        self.base_url = "https://openrouter.ai/api/v1/chat/completions"
-        self.key = key
-
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        response = requests.post(
-            url=self.base_url,
-            headers={
-                "Authorization": f"Bearer {self.key}",
-            },
-            data=json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": self.prompt(b64),
-                    "max_tokens": max_tokens,
-                }
-            ),
-        )
-        response = response.json()
-        return (
-            response["choices"][0]["message"]["content"].strip(),
-            response["usage"]["total_tokens"],
-        )
-
-    def prompt(self, b64):
-        return [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
-                    },
-                    {
-                        "type": "text",
-                        "text": (
-                            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
-                            if self.lang.lower() == "chinese"
-                            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                        ),
-                    },
-                ],
-            }
-        ]
 
 
 class LocalCV(Base):
@@ -675,12 +623,12 @@ class NvidiaCV(Base):
         ]
 
 
-class LmStudioCV(Base):
+class LmStudioCV(GptV4):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
-        if base_url.split(…
-        …
-        self.client = OpenAI(api_key="lm-studio", base_url=…
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
```
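The CV wrappers above now inherit `GptV4`, i.e. they send images in the standard OpenAI vision message format instead of hand-built `requests` payloads (the removed `OpenRouterCV.prompt` shows that format inline). A minimal sketch of such a call; the endpoint, key, and model name are placeholders:

```python
import base64

from openai import OpenAI

# Placeholder endpoint/key; any OpenAI-compatible vision server works the same way.
client = OpenAI(api_key="your-key", base_url="https://openrouter.ai/api/v1")

with open("photo.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

# OpenAI vision format: a text part plus an image_url part carrying a data URL.
res = client.chat.completions.create(
    model="some/vision-model",  # placeholder model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the content of this picture."},
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        ],
    }],
    max_tokens=300,
)
print(res.choices[0].message.content.strip(), res.usage.total_tokens)
```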
#### rag/llm/embedding_model.py (CHANGED)
```diff
@@ -113,21 +113,24 @@ class OpenAIEmbed(Base):
 
 class LocalAIEmbed(Base):
     def __init__(self, key, model_name, base_url):
-        …
-        …
-        …
-        …
+        if not base_url:
+            raise ValueError("Local embedding model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=…
-        …
-        …
-        …
-        …
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts, model=self.model_name)
+        return (
+            np.array([d.embedding for d in res.data]),
+            1024,
+        ) # local embedding for LmStudio donot count tokens
 
     def encode_queries(self, text):
-        …
-        return np.array(…
+        res = self.client.embeddings.create(text, model=self.model_name)
+        return np.array(res.data[0].embedding), 1024
+
 
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
@@ -502,7 +505,7 @@ class NvidiaEmbed(Base):
         return np.array(embds[0]), cnt
 
 
-class LmStudioEmbed(Base):
+class LmStudioEmbed(LocalAIEmbed):
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -510,14 +513,3 @@ class LmStudioEmbed(Base):
             self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
-
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts, model=self.model_name)
-        return (
-            np.array([d.embedding for d in res.data]),
-            1024,
-        ) # local embedding for LmStudio donot count tokens
-
-    def encode_queries(self, text):
-        res = self.client.embeddings.create(text, model=self.model_name)
-        return np.array(res.data[0].embedding), 1024
```
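With `LmStudioEmbed` now inheriting `LocalAIEmbed`, both share one OpenAI-format embeddings call; the hard-coded `1024` in the hunks above stands in for a token count because these local servers do not report usage. A hedged usage sketch against any OpenAI-compatible local embedding server (URL and model name are placeholders):

```python
import numpy as np
from openai import OpenAI

# Placeholder local endpoint; LocalAI and LM Studio both expose a /v1 API.
client = OpenAI(api_key="empty", base_url="http://localhost:8080/v1")

res = client.embeddings.create(input=["hello", "world"], model="my-embed-model")
vecs = np.array([d.embedding for d in res.data])  # shape: (2, embedding_dim)
print(vecs.shape)
```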