RAG_PEDIATRICS / src /embeddings.py
Stéphanie Kamgnia Wonkap
initial commit
a6e92fe
raw
history blame
534 Bytes
# Databricks notebook source
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
def init_embedding_model(
    EMBEDDING_MODEL_NAME: str,
    *,
    device: str = "cuda",
    multi_process: bool = True,
    normalize_embeddings: bool = True,
) -> "HuggingFaceEmbeddings":
    """Build a HuggingFace sentence-embedding model wrapped for LangChain.

    Args:
        EMBEDDING_MODEL_NAME: Model name or local path passed to
            ``HuggingFaceEmbeddings(model_name=...)``. The upper-case name is
            kept so existing keyword callers keep working.
        device: Torch device the model is loaded on, passed as
            ``model_kwargs={"device": device}``. Defaults to ``"cuda"`` (the
            previous hard-coded value); pass ``"cpu"`` on machines without a
            GPU instead of editing this function.
        multi_process: Forwarded unchanged to
            ``HuggingFaceEmbeddings(multi_process=...)``. Defaults to ``True``
            as before.
        normalize_embeddings: Forwarded via
            ``encode_kwargs={"normalize_embeddings": ...}``. Keep ``True``
            when ranking by cosine similarity: for unit-length vectors the
            inner product equals the cosine similarity.

    Returns:
        The configured ``HuggingFaceEmbeddings`` instance.
    """
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=multi_process,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )