# ChatIGL / langchain_src / qna_chain.py
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from langchain_community.document_transformers import LongContextReorder
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import HumanMessage
def generate_response(input, history=None):
    """Answer a user question with RAG over the persisted Chroma vector store.

    Pipeline: multi-query retrieval (an LLM rewrites the question into several
    variants and the retrieved documents are unioned) -> long-context reordering
    -> a "stuff" documents chain against gpt-4o-mini.

    Args:
        input: The user's question. (Name shadows the ``input`` builtin; kept
            unchanged for backward compatibility with existing callers.)
        history: Prior chat messages rendered into the prompt's
            ``<chat_history>`` section. Defaults to no history.

    Returns:
        The chain's answer (a string from the stuff-documents chain).
    """
    # Avoid a mutable default argument; None means "no prior conversation".
    history = [] if history is None else history

    # assumes a Chroma store was previously persisted at ./vector_db with
    # OpenAI embeddings — TODO confirm against the ingestion script.
    vector_db = Chroma(persist_directory="vector_db", embedding_function=OpenAIEmbeddings())

    # Base retriever: similarity search with a relevance cutoff, up to 10 docs.
    retriever = vector_db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.7, "k": 10},
    )

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.2)

    # Multi-query retriever wraps the base retriever with LLM query expansion.
    retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=llm)
    unique_docs = retriever.invoke(input)

    # LongContextReorder places the most relevant documents at the start and
    # end of the context to mitigate the "lost in the middle" effect.
    reordering = LongContextReorder()
    reordered_docs = reordering.transform_documents(unique_docs)

    SYSTEM_TEMPLATE = """
Answer the user's questions based on the below context.
The context will always be relevant to the question.
Use chat history if required to answer the question.
Always be very descriptive, unless stated otherwise.
Explain as if you are a master at this subject.
<context>
{context}
</context>
<chat_history>
{chat_history}
</chat_history>
"""
    question_answering_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                SYSTEM_TEMPLATE,
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    # create_stuff_documents_chain formats reordered_docs into {context}.
    document_chain = create_stuff_documents_chain(llm, question_answering_prompt)
    response = document_chain.invoke(
        {
            "context": reordered_docs,
            "messages": [
                HumanMessage(content=input)
            ],
            "chat_history": history,
        }
    )
    return response
if __name__ == "__main__":
    # Smoke test: a single question with no prior chat history. The original
    # call omitted the required `history` argument and raised a TypeError.
    prompt = "How should my MDPE pipe cross a Nallah?"
    response = generate_response(prompt, [])
    print(response)