RAG_PEDIATRICS / src /embeddings.py
Stéphanie Kamgnia Wonkap
fixing main
b99886d
raw
history blame
543 Bytes
# Databricks notebook source
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
def init_embedding_model(
    EMBEDDING_MODEL_NAME: str,
    *,
    device: str = "cuda",
    multi_process: bool = True,
    normalize_embeddings: bool = True,
) -> "HuggingFaceEmbeddings":
    """Create a LangChain ``HuggingFaceEmbeddings`` wrapper for a model.

    All optional arguments default to the values this function previously
    hard-coded, so existing calls ``init_embedding_model(name)`` behave
    exactly as before.

    Args:
        EMBEDDING_MODEL_NAME: Model name or path, forwarded unchanged as
            ``model_name`` to ``HuggingFaceEmbeddings``. (Parameter name kept
            in upper case for backward compatibility with keyword callers.)
        device: Device string placed in ``model_kwargs``. Defaults to
            ``"cuda"``; pass ``"cpu"`` on machines without a GPU instead of
            editing this function.
        multi_process: Forwarded to ``HuggingFaceEmbeddings``.
        normalize_embeddings: Forwarded in ``encode_kwargs``. When ``True``
            the embeddings are scaled to unit length, so an inner-product
            comparison between them equals cosine similarity.

    Returns:
        The configured ``HuggingFaceEmbeddings`` instance.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=multi_process,
        model_kwargs={"device": device},
        # Unit-length vectors make dot product == cosine similarity.
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )
    return embedding_model