KadiAPY_Coding_Assistant / embeddings.py
bupa1018's picture
Update embeddings.py
ffc7b61
raw
history blame
886 Bytes
from langchain.embeddings import HuggingFaceEmbeddings
import torch
def get_hf_embeddings(model_name=None):
    """Build a Hugging Face embedding model for the given model identifier.

    Args:
        model_name: Hugging Face model identifier. When ``None`` (the
            default), ``"BAAI/bge-base-en-v1.5"`` is used.

    Returns:
        A ``HuggingFaceEmbeddings`` instance wrapping the selected model.
    """
    # An alternative that was previously tried here is
    # "sentence-transformers/all-mpnet-base-v2".
    chosen_model = "BAAI/bge-base-en-v1.5" if model_name is None else model_name
    return HuggingFaceEmbeddings(model_name=chosen_model)
def get_SFR_Code_embedding_model(
    model_name="Salesforce/SFR-Embedding-Code-400M_R", device="auto"
):
    """Build a Hugging Face embedding model for the Salesforce SFR code model.

    Args:
        model_name: Hugging Face model identifier. Defaults to
            ``"Salesforce/SFR-Embedding-Code-400M_R"``.
        device: Device string forwarded to the model via ``model_kwargs``,
            e.g. ``"cpu"`` or ``"cuda"``. The special value ``"auto"``
            resolves to ``"cuda"`` when ``torch.cuda.is_available()`` is
            true, otherwise ``"cpu"``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance wrapping the selected model.
    """
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # trust_remote_code=True allows the loader to execute Python code shipped
    # in the model repository. Presumably this model needs it; confirm before
    # pointing model_name at a repository you do not trust.
    model_kwargs = {"device": device, "trust_remote_code": True}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
    )