# SUPASSISTANT / utils.py
# Author: benticha — initial commit (4ff2d98)
from langchain_chroma import Chroma
from langchain_nomic.embeddings import NomicEmbeddings
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import CohereRerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableMap
from langchain.schema import BaseRetriever
load_dotenv()
def retriever(n_docs=5):
    """Build a hybrid retriever over the SUP'COM knowledge base.

    Combines dense vector search (Chroma + Nomic embeddings) with sparse
    keyword search (BM25) in an ensemble, then re-ranks the merged results
    with Cohere.

    Args:
        n_docs: Number of documents each base retriever should return.

    Returns:
        A ContextualCompressionRetriever wrapping the ensemble retriever
        with a Cohere rerank compressor.
    """
    vector_database_path = "sup-knowledge-eng-nomic"
    embeddings_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
    vectorstore = Chroma(collection_name="sup-store-eng-nomic",
                         persist_directory=vector_database_path,
                         embedding_function=embeddings_model)
    # BUG FIX: `as_retriever(k=n_docs)` silently ignores `k`; the document
    # count must be passed via `search_kwargs`.
    vs_retriever = vectorstore.as_retriever(search_kwargs={"k": n_docs})
    # Fetch the collection dump once instead of calling .get() twice.
    store_dump = vectorstore.get()
    documents = [
        Document(page_content=text, metadata=metadata)
        for text, metadata in zip(store_dump["documents"], store_dump["metadatas"])
    ]
    keyword_retriever = BM25Retriever.from_documents(documents)
    keyword_retriever.k = n_docs
    ensemble_retriever = EnsembleRetriever(retrievers=[vs_retriever, keyword_retriever],
                                           weights=[0.5, 0.5])
    compressor = CohereRerank(model="rerank-english-v3.0")
    # Rerank the ensemble's merged results before handing them to the chain.
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )
# System prompt for the RAG chain; expects `{context}` and `{input}`
# placeholders to be filled by the chain's ingress map.
rag_prompt = """You are an assistant for question-answering tasks.
The questions that you will be asked will mainly be about SUP'COM (also known as Higher School Of Communication Of Tunis).
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{input}
Provide an answer to this question using only the above context.
Answer:"""
# Post-processing
def format_docs(docs):
    """Join each document's text into one string, separated by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0
                         ) -> Runnable:
    """Return a RAG chain defined primarily in LangChain Expression Language.

    Args:
        retriever: Retriever used to fetch context documents for each question.
        model_name: Groq chat model identifier.
        temp: Sampling temperature passed to the LLM.

    Returns:
        A Runnable that maps {"input": <question>} to the LLM's answer message.
    """
    def retrieve_context(input_text):
        # `.invoke()` is the current Runnable API; `get_relevant_documents()`
        # is deprecated in recent LangChain releases.
        docs = retriever.invoke(input_text)
        return format_docs(docs)
    # Fan the single input out into the two prompt variables.
    ingress = RunnableMap(
        {
            "input": lambda x: x["input"],
            "context": lambda x: retrieve_context(x["input"]),
        }
    )
    # The entire RAG prompt is supplied as one system message.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                rag_prompt
            )
        ]
    )
    llm = ChatGroq(model=model_name, temperature=temp)
    chain = ingress | prompt | llm
    return chain