fix SILICONFLOW embedding error (#2363)
Browse files### What problem does this PR solve?
#2335 fix SILICONFLOW embedding error
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Co-authored-by: Zhedong Cen <[email protected]>
- rag/llm/embedding_model.py +33 -4
rag/llm/embedding_model.py
CHANGED
@@ -577,11 +577,40 @@ class UpstageEmbed(OpenAIEmbed):
|
|
577 |
super().__init__(key, model_name, base_url)
|
578 |
|
579 |
|
580 |
-
class SILICONFLOWEmbed(
|
581 |
-
def __init__(
|
|
|
|
|
582 |
if not base_url:
|
583 |
-
base_url = "https://api.siliconflow.cn/v1"
|
584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
|
586 |
|
587 |
class ReplicateEmbed(Base):
|
|
|
577 |
super().__init__(key, model_name, base_url)
|
578 |
|
579 |
|
580 |
+
class SILICONFLOWEmbed(Base):
|
581 |
+
def __init__(
|
582 |
+
self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"
|
583 |
+
):
|
584 |
if not base_url:
|
585 |
+
base_url = "https://api.siliconflow.cn/v1/embeddings"
|
586 |
+
self.headers = {
|
587 |
+
"accept": "application/json",
|
588 |
+
"content-type": "application/json",
|
589 |
+
"authorization": f"Bearer {key}",
|
590 |
+
}
|
591 |
+
self.base_url = base_url
|
592 |
+
self.model_name = model_name
|
593 |
+
|
594 |
+
def encode(self, texts: list, batch_size=32):
|
595 |
+
payload = {
|
596 |
+
"model": self.model_name,
|
597 |
+
"input": texts,
|
598 |
+
"encoding_format": "float",
|
599 |
+
}
|
600 |
+
res = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
601 |
+
return (
|
602 |
+
np.array([d["embedding"] for d in res["data"]]),
|
603 |
+
res["usage"]["total_tokens"],
|
604 |
+
)
|
605 |
+
|
606 |
+
def encode_queries(self, text):
|
607 |
+
payload = {
|
608 |
+
"model": self.model_name,
|
609 |
+
"input": text,
|
610 |
+
"encoding_format": "float",
|
611 |
+
}
|
612 |
+
res = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
613 |
+
return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
|
614 |
|
615 |
|
616 |
class ReplicateEmbed(Base):
|