# NOTE(review): removed non-Python export artifacts ("Spaces:", "Running") that
# preceded the module and would break parsing.
# src/agents/rag_agent.py | |
from dataclasses import dataclass | |
from typing import List, Optional | |
from ..llms.base_llm import BaseLLM | |
from src.embeddings.base_embedding import BaseEmbedding | |
from src.vectorstores.base_vectorstore import BaseVectorStore | |
from src.utils.text_splitter import split_text | |
@dataclass
class RAGResponse:
    """Result of a RAG generation: the answer plus the context it was built from.

    Fix: the ``@dataclass`` decorator was missing. Without it the annotated
    fields are not turned into an ``__init__``, so the keyword construction
    ``RAGResponse(response=..., context_docs=...)`` used by
    ``RAGAgent.generate_response`` raises ``TypeError``.
    """

    # Generated answer text returned by the LLM.
    response: str
    # Context documents that were fed to the LLM; None when not supplied.
    context_docs: Optional[List[str]] = None
class RAGAgent:
    """Retrieval-Augmented Generation agent.

    Embeds a query, retrieves similar documents from a vector store, and asks
    the LLM to answer the query grounded in that retrieved context.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore
    ):
        """Store the collaborators used by the retrieval and generation steps.

        Args:
            llm: Text generator; must expose ``generate(prompt) -> str``.
            embedding: Query embedder; must expose ``embed_query(str)``.
            vector_store: Similarity index; must expose
                ``similarity_search(embedding, top_k=...)``.
        """
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3
    ) -> List[str]:
        """Retrieve the most relevant context documents for a query.

        Args:
            query: Input query to find context for.
            top_k: Number of top context documents to retrieve.

        Returns:
            List[str]: Retrieved context documents, most similar first
            (ordering is whatever the vector store returns).
        """
        # Embed the query, then look up nearest neighbours in the store.
        query_embedding = self.embedding.embed_query(query)
        context_docs = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )
        return context_docs

    def generate_response(
        self,
        query: str,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using the RAG approach.

        Args:
            query: User input query.
            context_docs: Optional pre-provided context documents. Pass an
                empty list to deliberately generate with no context.

        Returns:
            RAGResponse: Generated text plus the context that was used.
        """
        # Fix: test `is None` rather than truthiness so an explicitly
        # provided empty list is respected instead of silently triggering
        # a retrieval the caller did not ask for.
        if context_docs is None:
            context_docs = self.retrieve_context(query)

        augmented_prompt = self._construct_prompt(query, context_docs)
        response = self.llm.generate(augmented_prompt)

        return RAGResponse(
            response=response,
            context_docs=context_docs
        )

    def _construct_prompt(
        self,
        query: str,
        context_docs: List[str]
    ) -> str:
        """Build the context-augmented prompt sent to the LLM.

        Args:
            query: Original user query.
            context_docs: Retrieved context documents.

        Returns:
            str: Augmented prompt embedding the context and the query.
        """
        # Blank line between documents keeps them visually separate for the LLM.
        context_str = "\n\n".join(context_docs)

        return f"""
        Context Information:
        {context_str}

        User Query: {query}

        Based on the context, please provide a comprehensive and accurate response.
        """