黄腾 aopstudio committed on
Commit
16d7b7b
·
1 Parent(s): c2d189b

fix SILICONFLOW embedding error (#2363)

Browse files

### What problem does this PR solve?

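Fixes #2335: SILICONFLOW embedding error. `SILICONFLOWEmbed` now derives from `Base` and posts directly to the `/v1/embeddings` endpoint with `requests` instead of going through `OpenAIEmbed`.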

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhedong Cen <[email protected]>

Files changed (1) hide show
  1. rag/llm/embedding_model.py +33 -4
rag/llm/embedding_model.py CHANGED
@@ -577,11 +577,40 @@ class UpstageEmbed(OpenAIEmbed):
577
  super().__init__(key, model_name, base_url)
578
 
579
 
580
- class SILICONFLOWEmbed(OpenAIEmbed):
581
- def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
 
 
582
  if not base_url:
583
- base_url = "https://api.siliconflow.cn/v1"
584
- super().__init__(key, model_name, base_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
 
587
  class ReplicateEmbed(Base):
 
577
  super().__init__(key, model_name, base_url)
578
 
579
 
580
class SILICONFLOWEmbed(Base):
    """Embedding client that POSTs JSON requests to a SiliconFlow embeddings URL.

    Requests are sent with ``requests.post`` using a bearer-token
    ``authorization`` header built from the supplied key.
    """

    def __init__(
        self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"
    ):
        """Store the endpoint, model name and request headers.

        Args:
            key: API key, sent as ``Bearer <key>`` in the authorization header.
            model_name: Value sent as the ``model`` field of every request.
            base_url: Full URL of the embeddings endpoint. A falsy value
                (``None`` or ``""``) falls back to the default URL.
        """
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = base_url
        self.model_name = model_name

    def _request(self, inputs):
        """POST one embeddings request and return the decoded JSON body.

        Args:
            inputs: Value sent as the ``input`` field (a string or a list of
                strings).

        Raises:
            KeyError: If the decoded body lacks the ``data`` or ``usage``
                fields. The body is included in the message so the server's
                error text is visible.
        """
        payload = {
            "model": self.model_name,
            "input": inputs,
            "encoding_format": "float",
        }
        res = requests.post(self.base_url, json=payload, headers=self.headers).json()
        if not isinstance(res, dict) or "data" not in res or "usage" not in res:
            # KeyError matches what a bare res["data"] lookup raises for a
            # missing field, so callers see the same exception type.
            raise KeyError(
                f"SILICONFLOW embedding response is missing 'data'/'usage': {res}"
            )
        return res

    def encode(self, texts: list, batch_size=32):
        """Embed a list of texts, sending at most ``batch_size`` per request.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts per HTTP request. A falsy or
                non-positive value sends everything in a single request.

        Returns:
            A tuple ``(embeddings, total_tokens)``: a numpy array with one
            row per input text, in input order, and the sum of
            ``usage.total_tokens`` over all requests. Empty input returns
            ``(np.array([]), 0)`` without contacting the server.
        """
        if not texts:
            return np.array([]), 0

        step = batch_size if batch_size and batch_size > 0 else len(texts)
        embeddings = []
        total_tokens = 0
        for start in range(0, len(texts), step):
            res = self._request(texts[start:start + step])
            embeddings.extend(d["embedding"] for d in res["data"])
            total_tokens += res["usage"]["total_tokens"]
        return np.array(embeddings), total_tokens

    def encode_queries(self, text):
        """Embed a single query text.

        Args:
            text: Query text, sent unchanged as the ``input`` field.

        Returns:
            A tuple ``(embedding, total_tokens)``: the first returned
            embedding as a numpy array and the reported token usage.
        """
        res = self._request(text)
        return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
614
 
615
 
616
  class ReplicateEmbed(Base):