bupa1018 commited on
Commit
25830df
·
1 Parent(s): 9b72c2e

Update embeddings.py

Browse files
Files changed (1) hide show
  1. embeddings.py +18 -0
embeddings.py CHANGED
@@ -9,4 +9,22 @@ def get_hf_embeddings(model_name=None):
9
 
10
  embeddings = HuggingFaceEmbeddings(model_name=model_name)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  return embeddings
 
9
 
10
  embeddings = HuggingFaceEmbeddings(model_name=model_name)
11
 
12
+ return embeddings
13
+
14
+ def get_SFR_Code_embedding_model(
15
+ model_name="Salesforce/SFR-Embedding-Code-400M_R", device="auto"
16
+ ):
17
+ """Get jinaai embedding."""
18
+
19
+ # device: cpu or cuda
20
+ if device == "auto":
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+ model_name = model_name
24
+ model_kwargs = {"device": device, "trust_remote_code": True}
25
+ embeddings = HuggingFaceEmbeddings(
26
+ model_name=model_name,
27
+ model_kwargs=model_kwargs,
28
+ )
29
+
30
  return embeddings