# NOTE(review): stray extraction artifacts ("Spaces:" / "Sleeping" lines) removed;
# they were not valid Python and carried no content.
from operator import itemgetter

from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

from prompt_templates import PromptTemplates
class RetrievalManager:
    """
    Manages retrieval-augmented question answering (RAG QA).

    Builds a LangChain runnable that retrieves context for a question and
    feeds both the context and the question to an OpenAI chat model through
    a prompt template.

    Attributes:
        retriever (object): Retriever used to fetch context documents;
            composed into the chain via ``itemgetter("question") | retriever``.
        chat_model (ChatOpenAI): The OpenAI chat model used to answer.
        prompts (PromptTemplates): Source of the RAG QA prompt template.

    Methods:
        get_RAG_QA_chain():
            Build and return the retrieval-augmented QA chain.
        notebook_QA(question):
            Run the chain on a question and return the answer text.
    """

    def __init__(self, retriever, model="gpt-4-turbo", temperature=0.1):
        """
        Initialize the manager.

        Parameters:
            retriever (object): The retriever object used for context lookup.
            model (str): OpenAI chat model name. Defaults to "gpt-4-turbo",
                preserving the previous hard-coded behavior.
            temperature (float): Sampling temperature. Defaults to 0.1,
                preserving the previous hard-coded behavior.
        """
        self.retriever = retriever
        self.chat_model = ChatOpenAI(model=model, temperature=temperature)
        self.prompts = PromptTemplates()

    def get_RAG_QA_chain(self):
        """
        Build and return the retrieval-augmented QA chain.

        Returns:
            Runnable: A chain mapping ``{"question": str}`` to
            ``{"response": <model message>, "context": <retrieved docs>}``.
        """
        return (
            {
                "context": itemgetter("question") | self.retriever,
                "question": itemgetter("question"),
            }
            # Passthrough keeps the retrieved context available to the
            # final step alongside the generated response.
            | RunnablePassthrough.assign(context=itemgetter("context"))
            | {
                "response": self.prompts.get_rag_qa_prompt() | self.chat_model,
                "context": itemgetter("context"),
            }
        )

    def notebook_QA(self, question):
        """
        Processes a question using the retrieval-augmented QA chain and
        returns the response.

        Parameters:
            question (str): The question to be processed.

        Returns:
            str: The text content of the model's response message.
        """
        # Reuse the single chain definition instead of duplicating the
        # chain construction here (the two copies previously had to be
        # kept in sync by hand).
        chain = self.get_RAG_QA_chain()
        response = chain.invoke({"question": question})
        return response["response"].content