# chatbot-backend/src/agents/rag_agent.py
# Provenance: TalatMasood — commit "Commit chatbot chnages" (0739c8b), 4.37 kB
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text
from src.models.rag import RAGResponse
class RAGAgent:
    """Retrieval-Augmented Generation agent.

    Wires together an embedding model, a vector store, and an LLM:
    queries are embedded, similar documents are fetched from the store,
    and the LLM is prompted with the retrieved context.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore
    ):
        # Collaborators are injected so any concrete LLM / embedding /
        # vector-store implementation can be swapped in.
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of *metadata* with int/float values stringified."""
        return {
            key: str(value) if isinstance(value, (int, float)) else value
            for key, value in metadata.items()
        }

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve relevant context documents for a given query.

        Args:
            query (str): Input query to find context for
            top_k (int): Number of top context documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved document texts, their metadata (numeric values
                stringified), and similarity scores — scores is None
                unless every result carries one.
        """
        embedded_query = self.embedding.embed_query(query)

        matches = self.vector_store.similarity_search(
            embedded_query,
            top_k=top_k
        )

        texts = [match['text'] for match in matches]
        metadatas = [
            self._convert_metadata_to_strings(match['metadata'])
            for match in matches
        ]

        # Scores are only meaningful when every match carries one;
        # otherwise drop them entirely rather than return a partial list.
        raw_scores = [match.get('score') for match in matches]
        scores = raw_scores if all(s is not None for s in raw_scores) else None

        return texts, metadatas, scores

    async def generate_response(
        self,
        query: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """
        Generate a response using the RAG approach.

        Args:
            query (str): User input query
            temperature (float): Sampling temperature for the LLM
            max_tokens (Optional[int]): Maximum tokens to generate
            context_docs (Optional[List[str]]): Optional pre-provided
                context documents; when absent (or empty) context is
                retrieved from the vector store instead.

        Returns:
            RAGResponse: Response with generated text and context
        """
        sources = None
        scores = None
        # An empty list counts as "no context" and triggers retrieval.
        if not context_docs:
            context_docs, sources, scores = self.retrieve_context(query)

        prompt = self._construct_prompt(query, context_docs)

        # NOTE(review): this coroutine never awaits llm.generate —
        # confirm BaseLLM.generate is synchronous, otherwise `response`
        # would be an unawaited coroutine object.
        response = self.llm.generate(
            prompt,
            temperature=temperature,
            max_tokens=max_tokens
        )

        return RAGResponse(
            response=response,
            context_docs=context_docs,
            sources=sources,
            scores=scores
        )

    def _construct_prompt(
        self,
        query: str,
        context_docs: List[str]
    ) -> str:
        """
        Build the LLM prompt by prepending retrieved context to the query.

        Args:
            query (str): Original user query
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Augmented prompt for the LLM
        """
        joined_context = "\n\n".join(context_docs)
        return f"""
Context Information:
{joined_context}
User Query: {query}
Based on the context, please provide a comprehensive and accurate response.
"""