RAG_PEDIATRICS / src /retriever.py
Stéphanie Kamgnia Wonkap
initial commit
a6e92fe
raw
history blame
1.51 kB
# Databricks notebook source
from typing import List,Optional
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import RagRetriever
from langchain.docstore.document import Document as LangchainDocument
def init_vectorDB_from_doc(
    documents: List[LangchainDocument],
    embedding_model: Embeddings,
) -> FAISS:
    """Build a FAISS vector store from ``documents``.

    Args:
        documents: Documents handed to ``FAISS.from_documents``.
        embedding_model: Embedding model handed to ``FAISS.from_documents``.

    Returns:
        The vector store created by ``FAISS.from_documents``, configured with
        ``DistanceStrategy.COSINE`` as its distance strategy.
    """
    return FAISS.from_documents(
        documents,
        embedding_model,
        distance_strategy=DistanceStrategy.COSINE,
    )
def retriever(
    user_query: str,
    vectorDB: FAISS,
    reranker=None,
    num_doc_before_rerank: int = 5,
    num_final_relevant_docs: int = 5,
    rerank: bool = True,
) -> List[str]:
    """Retrieve the passages most relevant to ``user_query``.

    Runs a similarity search against ``vectorDB`` and, when enabled, passes
    the candidate texts through ``reranker`` before returning them.

    Args:
        user_query: Query text to search for.
        vectorDB: Vector store queried through ``similarity_search``.
        reranker: Optional object exposing ``rerank(query, documents, k=...)``
            whose results are indexable by ``"content"``. Reranking is
            skipped when this is ``None``.
        num_doc_before_rerank: Number of candidates requested from the
            vector store.
        num_final_relevant_docs: Maximum number of passages returned, whether
            or not reranking runs.
        rerank: Set to ``False`` to skip reranking even when a reranker is
            supplied.

    Returns:
        The text of the selected passages, capped at
        ``num_final_relevant_docs`` entries.
    """
    candidates = vectorDB.similarity_search(query=user_query, k=num_doc_before_rerank)
    candidate_texts = [doc.page_content for doc in candidates]  # Keep only the text
    print("=> Relevant documents:")
    print(candidate_texts)

    # Test the reranker against None rather than by truthiness so that a
    # reranker object that happens to evaluate falsy is not silently ignored.
    # An empty candidate list leaves nothing to reorder, so skip the call.
    if rerank and reranker is not None and candidate_texts:
        reranked = reranker.rerank(user_query, candidate_texts, k=num_final_relevant_docs)
        final_relevant_docs = [doc["content"] for doc in reranked]
        print("=> Reranked documents:")
        print(final_relevant_docs)
    else:
        final_relevant_docs = candidate_texts

    # Apply the cap on both paths: without reranking, the similarity search
    # may have returned more candidates than the caller wants back.
    final_relevant_docs = final_relevant_docs[:num_final_relevant_docs]
    print("=> Final relevant documents:")
    print(final_relevant_docs)
    return final_relevant_docs