chatbot-backend / src /agents /rag_agent.py
TalatMasood's picture
Added support for multiple LLMs
e87abff
raw
history blame
2.98 kB
# src/agents/rag_agent.py
from dataclasses import dataclass
from typing import List, Optional
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text
@dataclass()
class RAGResponse:
    """Result of a single RAG generation call.

    Attributes:
        response: Text returned by the LLM for the augmented prompt.
        context_docs: Context documents associated with the response;
            ``None`` when none were attached.
    """

    # Generated answer text.
    response: str
    # Documents that accompanied the generation (optional, defaults to None).
    context_docs: Optional[List[str]] = None
class RAGAgent:
    """Retrieval-augmented generation agent.

    Embeds a query, looks up similar documents in a vector store, and asks an
    LLM to answer the query with those documents included in the prompt.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore
    ):
        """
        Args:
            llm (BaseLLM): Model whose ``generate(prompt)`` produces the answer.
            embedding (BaseEmbedding): Model whose ``embed_query(query)``
                turns the query into a search vector.
            vector_store (BaseVectorStore): Store whose
                ``similarity_search(embedding, top_k=...)`` returns context.
        """
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3
    ) -> List[str]:
        """
        Retrieve relevant context documents for a given query.

        Args:
            query (str): Input query to find context for.
            top_k (int): Number of top context documents to request from
                the vector store.

        Returns:
            List[str]: Whatever ``vector_store.similarity_search`` returns.
            NOTE(review): assumed to be a list of document strings — confirm
            against the concrete vector store implementations.
        """
        # Embed the query so it can be compared against stored vectors.
        query_embedding = self.embedding.embed_query(query)
        # Ask the vector store for the top_k nearest documents.
        context_docs = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )
        return context_docs

    def generate_response(
        self,
        query: str,
        context_docs: Optional[List[str]] = None,
        top_k: int = 3
    ) -> RAGResponse:
        """
        Generate a response using the RAG approach.

        Args:
            query (str): User input query.
            context_docs (Optional[List[str]]): Pre-provided context documents.
                When this is ``None`` *or empty*, context is retrieved from the
                vector store instead.
            top_k (int): Number of documents to retrieve when retrieval is
                needed. Ignored if non-empty ``context_docs`` are supplied.

        Returns:
            RAGResponse: The generated text plus the context documents that
            were placed in the prompt.
        """
        # Falsy check on purpose: an explicitly empty list also falls back to
        # retrieval, matching the historical behavior callers rely on.
        if not context_docs:
            context_docs = self.retrieve_context(query, top_k=top_k)

        augmented_prompt = self._construct_prompt(query, context_docs)
        response = self.llm.generate(augmented_prompt)

        return RAGResponse(
            response=response,
            context_docs=context_docs
        )

    def _construct_prompt(
        self,
        query: str,
        context_docs: Optional[List[str]]
    ) -> str:
        """
        Construct a prompt that embeds the retrieved context.

        Args:
            query (str): Original user query.
            context_docs (Optional[List[str]]): Context documents, joined with
                blank lines between them. ``None`` is treated as no context.

        Returns:
            str: Augmented prompt for the LLM.
        """
        # Guard against a vector store that yields None instead of a list;
        # joining None would raise TypeError.
        context_str = "\n\n".join(context_docs or [])
        # Adjacent literals concatenate into the same text the former
        # triple-quoted f-string produced, independent of source indentation.
        return (
            "\n"
            "Context Information:\n"
            f"{context_str}\n"
            f"User Query: {query}\n"
            "Based on the context, please provide a comprehensive and accurate response.\n"
        )