File size: 1,505 Bytes
a6e92fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Databricks notebook source
from typing import List,Optional
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings 
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import RagRetriever
from langchain.docstore.document import Document as LangchainDocument

def init_vectorDB_from_doc(documents: List[LangchainDocument], embedding_model: Embeddings) -> FAISS:
    """Create a FAISS vector store from ``documents`` using ``embedding_model``.

    The store is configured with ``DistanceStrategy.COSINE``.

    NOTE(review): confirm that ``embedding_model`` produces L2-normalized
    vectors; depending on the LangChain version, the COSINE strategy may not
    normalize vectors on its own.

    Args:
        documents: LangChain documents to index.
        embedding_model: Embedding model handed to ``FAISS.from_documents``.

    Returns:
        The newly built FAISS vector store.
    """
    return FAISS.from_documents(
        documents,
        embedding_model,
        distance_strategy=DistanceStrategy.COSINE,
    )





def retriever(
    user_query: str,
    vectorDB: "FAISS",
    reranker = None,
    num_doc_before_rerank: int = 5,
    num_final_relevant_docs: int = 5,
    rerank: bool = True
) -> List[str]:
    """Retrieve the text of the documents most relevant to ``user_query``.

    Runs a similarity search on ``vectorDB`` for ``num_doc_before_rerank``
    candidates, optionally reranks them with ``reranker``, and returns at most
    ``num_final_relevant_docs`` page-content strings. Candidate and final
    lists are printed for debugging.

    Args:
        user_query: Query text passed to the similarity search and reranker.
        vectorDB: Vector store exposing ``similarity_search(query=..., k=...)``
            whose results carry a ``page_content`` attribute.
        reranker: Optional object with ``rerank(query, docs, k=...)`` that
            returns dicts containing a ``"content"`` key. Reranking is skipped
            when this is ``None``.
        num_doc_before_rerank: Number of candidates requested from the
            similarity search.
        num_final_relevant_docs: Maximum number of documents returned, whether
            or not reranking happens.
        rerank: Set to ``False`` to bypass ``reranker`` even when one is given.

    Returns:
        The selected documents' text, in the order produced by the similarity
        search or by the reranker.
    """
    search_results = vectorDB.similarity_search(query=user_query, k=num_doc_before_rerank)
    candidates = [doc.page_content for doc in search_results]  # Keep only the text
    print("=> Relevant documents:")
    print(candidates)

    # Only rerank when asked to, when a reranker exists, and when there is
    # something to rerank (an empty candidate list has nothing to reorder).
    if rerank and reranker is not None and candidates:
        reranked = reranker.rerank(user_query, candidates, k=num_final_relevant_docs)
        # Slice defensively in case the reranker returns more than k results.
        final_relevant_docs = [doc["content"] for doc in reranked][:num_final_relevant_docs]
        print("=> Reranked documents:")
        print(final_relevant_docs)
    else:
        # Without reranking, still honour the requested final document count.
        final_relevant_docs = candidates[:num_final_relevant_docs]
        print("=> Final relevant documents:")
        print(final_relevant_docs)
    return final_relevant_docs