Spaces:
Sleeping
Sleeping
from aimakerspace.openai_utils.prompts import ( | |
UserRolePrompt, | |
SystemRolePrompt, | |
AssistantRolePrompt, | |
) | |
from aimakerspace.vectordatabase import VectorDatabase | |
from aimakerspace.openai_utils.chatmodel import ChatOpenAI | |
class RetrievalAugmentedQAPipeline:
    """Answer a user query with retrieval-augmented generation.

    Context passages are retrieved from a vector database, rendered into the
    system and user prompts, and the chat model's answer is streamed back to
    the caller as an async generator.
    """

    def __init__(
        self,
        system_role_prompt: "SystemRolePrompt",
        user_role_prompt: "UserRolePrompt",
        # The annotation must name the class, not call it: writing
        # ``ChatOpenAI()`` here would construct a client whenever this
        # module is imported, because annotations are evaluated at def time.
        llm: "ChatOpenAI",
        vector_db_retriever: "VectorDatabase",
    ) -> None:
        """Store the prompts, chat model and retriever used by the pipeline.

        Args:
            system_role_prompt: Prompt whose ``create_message()`` (called with
                no arguments) produces the system message.
            user_role_prompt: Prompt whose ``create_message(question=...,
                context=...)`` produces the user message.
            llm: Chat model whose ``astream(messages)`` yields answer chunks.
            vector_db_retriever: Store queried via ``search_by_text``.
        """
        self.system_role_prompt = system_role_prompt
        self.user_role_prompt = user_role_prompt
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str, *, k: int = 4) -> dict:
        """Retrieve context for *user_query* and stream the model's answer.

        Args:
            user_query: The question to answer.
            k: Number of results requested from the retriever (default 4).

        Returns:
            A dict with two keys:

            * ``"response"``: an async generator yielding the chunks produced
              by ``llm.astream``. The model is not called until this
              generator is iterated.
            * ``"context"``: the list returned by ``search_by_text``,
              unchanged.
        """
        # search_by_text is called synchronously here (it is not awaited).
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)

        # Element [0] of each result is used as the passage text. The results
        # are presumably (text, score) pairs; confirm against VectorDatabase.
        # Every passage is newline-terminated. join() avoids the quadratic
        # cost of building the string with += in a loop.
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        messages = [
            self.system_role_prompt.create_message(),
            self.user_role_prompt.create_message(
                question=user_query, context=context_prompt
            ),
        ]

        async def generate_response():
            # Relay the model's stream chunk by chunk; runs only when iterated.
            async for chunk in self.llm.astream(messages):
                yield chunk

        return {"response": generate_response(), "context": context_list}