Spaces:
Sleeping
Sleeping
# Databricks notebook source | |
from typing import List,Optional | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings.base import Embeddings | |
from langchain_community.vectorstores.utils import DistanceStrategy | |
from transformers import RagRetriever | |
from langchain.docstore.document import Document as LangchainDocument | |
def init_vectorDB_from_doc(
    documents: List[LangchainDocument],
    embedding_model: Embeddings,
) -> FAISS:
    """Build a FAISS vector store from a list of documents.

    Args:
        documents: LangChain documents to index.
        embedding_model: Embedding model handed to FAISS to vectorise
            the documents.

    Returns:
        The FAISS index created by ``FAISS.from_documents``, configured
        with the cosine distance strategy.
    """
    return FAISS.from_documents(
        documents,
        embedding_model,
        distance_strategy=DistanceStrategy.COSINE,
    )
def retriever(
    user_query: str,
    vectorDB: FAISS,
    reranker=None,
    num_doc_before_rerank: int = 5,
    num_final_relevant_docs: int = 5,
    rerank: bool = True,
) -> List[str]:
    """Retrieve the text of the documents most relevant to a query.

    The vector store is searched for ``num_doc_before_rerank`` candidates.
    When reranking is enabled and a reranker is supplied, the candidates are
    re-ordered by the reranker; otherwise the similarity-search order is kept.
    In both cases at most ``num_final_relevant_docs`` texts are returned.

    Args:
        user_query: The query to search for.
        vectorDB: Vector store exposing ``similarity_search(query=..., k=...)``
            and returning objects with a ``page_content`` attribute.
        reranker: Optional object exposing ``rerank(query, docs, k=...)`` that
            returns a sequence of mappings with a ``"content"`` key.
        num_doc_before_rerank: Number of candidates fetched from the vector
            store.
        num_final_relevant_docs: Maximum number of texts returned.
        rerank: Set to ``False`` to skip reranking even if a reranker is given.

    Returns:
        The page contents of the selected documents, most relevant first.
    """
    candidates = vectorDB.similarity_search(
        query=user_query, k=num_doc_before_rerank
    )
    relevant_docs = [doc.page_content for doc in candidates]  # keep only the text
    print("=> Relevant documents:")
    print(relevant_docs)

    # Only rerank when there is something to rerank: an empty candidate list
    # has nothing to order and some rerankers reject empty input.
    if rerank and reranker is not None and relevant_docs:
        reranked = reranker.rerank(
            user_query, relevant_docs, k=num_final_relevant_docs
        )
        # Slice defensively in case the reranker returns more than k results.
        final_relevant_docs = [doc["content"] for doc in reranked][
            :num_final_relevant_docs
        ]
        print("=> Reranked documents:")
        print(final_relevant_docs)
    else:
        # Honour the requested final count even without reranking.
        final_relevant_docs = relevant_docs[:num_final_relevant_docs]
        print("=> Final relevant documents:")
        print(final_relevant_docs)
    return final_relevant_docs