# utils.py: retriever, prompt, and RAG chain helpers for the medical chatbot

from dotenv import load_dotenv

from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableMap
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.schema import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_groq import ChatGroq
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_nomic.embeddings import NomicEmbeddings  # optional local alternative
from qdrant_client import models

load_dotenv()  # load API keys (e.g. COHERE_API_KEY, GROQ_API_KEY) from a .env file


# Hybrid retriever: dense (Chroma) + keyword (BM25), reranked with Cohere
def retriever(n_docs=5):
    """Build a hybrid retriever over the persisted Chroma store."""
    vector_database_path = "chromadb3"

    # embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    vectorstore = Chroma(
        collection_name="chroma_db",
        persist_directory=vector_database_path,
        embedding_function=embedding_model,
    )

    # Dense retriever over the Chroma vector store
    vs_retriever = vectorstore.as_retriever(search_kwargs={"k": n_docs})

    # Rebuild the raw documents so BM25 can index the same corpus
    store = vectorstore.get()
    documents = [
        Document(page_content=text, metadata=metadata)
        for text, metadata in zip(store["documents"], store["metadatas"])
    ]

    keyword_retriever = BM25Retriever.from_documents(documents)
    keyword_retriever.k = n_docs

    # Combine dense and sparse retrieval with equal weights
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vs_retriever, keyword_retriever],
        weights=[0.5, 0.5],
    )

    # Rerank the merged candidates with Cohere before returning them
    compressor = CohereRerank(model="rerank-english-v3.0")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )

    return compression_retriever
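
# Example usage (a sketch): query the hybrid retriever directly to inspect the
# reranked results. Assumes the Chroma store at "chromadb3" is populated and
# COHERE_API_KEY is set for the reranker; the sample question is illustrative.
#
#     rag_retriever = retriever(n_docs=5)
#     for doc in rag_retriever.invoke("What are the common symptoms of anemia?"):
#         print(doc.metadata.get("relevance_score"), doc.page_content[:80])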

# Generation prompt for the RAG chain
rag_prompt = """You are a medical chatbot designed to answer health-related questions.
The questions you will receive will primarily focus on medical topics and patient care.
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{input}
Provide an answer to this question using only the above context.
Answer:"""

# Post-processing
def format_docs(docs):
    """Concatenate retrieved documents into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)

# RAG chain
def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0) -> Runnable:
    """Return a chain defined primarily in LangChain Expression Language."""

    def retrieve_context(input_text):
        # Use the retriever to fetch relevant documents and format them as context
        docs = retriever.invoke(input_text)
        return format_docs(docs)

    # Map the incoming {"input": ...} dict onto the prompt variables
    ingress = RunnableMap(
        {
            "input": lambda x: x["input"],
            "context": lambda x: retrieve_context(x["input"]),
        }
    )
    prompt = ChatPromptTemplate.from_messages([("system", rag_prompt)])
    llm = ChatGroq(model=model_name, temperature=temp)

    chain = ingress | prompt | llm
    return chain
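
# Convenience wrapper (a sketch): composes the hybrid retriever and the LCEL chain
# above into a single call. Assumes the Chroma store at "chromadb3" is populated
# and that COHERE_API_KEY and GROQ_API_KEY are available in the environment.
def answer_question(question: str, n_docs: int = 5) -> str:
    """Run the full RAG pipeline for one question and return the model's answer text."""
    rag_retriever = retriever(n_docs=n_docs)
    chain = get_expression_chain(rag_retriever)
    response = chain.invoke({"input": question})
    return response.content
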

# Shared embedding model used for Qdrant (768-dimensional vectors)
# embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")


# Generate an embedding vector for a given text
def get_embeddings(text):
    return embedding_model.embed_query(text)


# Create a Qdrant collection if it does not already exist
def create_qdrant_collection(client, collection_name):
    existing = [c.name for c in client.get_collections().collections]
    if collection_name not in existing:
        client.create_collection(
            collection_name=collection_name,
            # 768 matches the output dimension of all-mpnet-base-v2
            vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
        )
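

if __name__ == "__main__":
    # Minimal smoke test (a sketch): exercises the Qdrant helpers against an
    # in-memory instance, so no running Qdrant server is needed. The collection
    # name and sample text below are illustrative placeholders.
    from qdrant_client import QdrantClient

    qdrant = QdrantClient(":memory:")
    create_qdrant_collection(qdrant, "medical_docs")

    sample_text = "Iron-deficiency anemia often presents with fatigue and pallor."
    qdrant.upsert(
        collection_name="medical_docs",
        points=[
            models.PointStruct(
                id=1,
                vector=get_embeddings(sample_text),
                payload={"text": sample_text},
            )
        ],
    )
    print("points in collection:", qdrant.count(collection_name="medical_docs").count)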