from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI


class RetrievalAugmentedQAPipeline:
    """Answer a question with an LLM, grounded in passages from a vector database.

    For each query the pipeline retrieves the best-matching passages, renders
    them into the user prompt template alongside the question, and returns a
    stream of the chat model's reply together with the retrieved results.
    """

    def __init__(
        self,
        system_role_prompt: SystemRolePrompt,
        user_role_prompt: UserRolePrompt,
        # Annotate with the class, not an instance: annotation expressions are
        # evaluated when the function is defined, so ``ChatOpenAI()`` here would
        # construct a client as a side effect of importing this module.
        llm: ChatOpenAI,
        vector_db_retriever: VectorDatabase,
    ) -> None:
        """Store the prompt templates, chat model and retriever.

        Args:
            system_role_prompt: Template rendered with no variables as the
                system message.
            user_role_prompt: Template rendered with ``question`` and
                ``context`` variables as the user message.
            llm: Chat model; must provide an ``astream(messages)`` async
                iterator.
            vector_db_retriever: Store queried via
                ``search_by_text(query, k=...)``.
        """
        self.system_role_prompt = system_role_prompt
        self.user_role_prompt = user_role_prompt
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str, k: int = 4) -> dict:
        """Retrieve context for ``user_query`` and stream the LLM's answer.

        Args:
            user_query: The question; used both as the retrieval query and as
                the ``question`` variable of the user prompt.
            k: Number of results requested from the retriever (default 4).

        Returns:
            A dict with two keys:
              - ``"response"``: an async generator yielding the chunks produced
                by ``self.llm.astream``; its body (and thus the LLM call) runs
                only once the caller starts iterating it.
              - ``"context"``: the raw result list from ``search_by_text``.
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)

        # Element 0 of each result is the passage text (presumably paired with
        # a similarity score — confirm against VectorDatabase.search_by_text).
        # Each passage is followed by a newline, including the last one.
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        formatted_system_prompt = self.system_role_prompt.create_message()
        formatted_user_prompt = self.user_role_prompt.create_message(
            question=user_query, context=context_prompt
        )
        messages = [formatted_system_prompt, formatted_user_prompt]

        async def generate_response():
            # Forward streamed chunks as they arrive so callers can render
            # the answer incrementally.
            async for chunk in self.llm.astream(messages):
                yield chunk

        return {"response": generate_response(), "context": context_list}