# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text
from src.models.rag import RAGResponse
class RAGAgent:
    """Retrieval-augmented generation (RAG) agent.

    Wires together three injected collaborators: an embedding model to embed
    queries, a vector store to retrieve similar documents, and an LLM to
    generate the final answer with the retrieved context prepended.

    Project types in annotations are written as string forward references so
    the class object itself has no hard dependency on those names at
    class-creation time.
    """

    def __init__(
        self,
        llm: "BaseLLM",
        embedding: "BaseEmbedding",
        vector_store: "BaseVectorStore",
    ):
        """Store the injected LLM, embedding model, and vector store."""
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of *metadata* with int/float values stringified.

        Non-numeric values are passed through unchanged. NOTE: ``bool`` is a
        subclass of ``int``, so True/False are also converted to "True"/"False"
        (preserved from the original behavior).
        """
        return {
            key: str(value) if isinstance(value, (int, float)) else value
            for key, value in metadata.items()
        }

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3,
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """Retrieve relevant context documents for a given query.

        Args:
            query (str): Input query to find context for.
            top_k (int): Number of top context documents to retrieve.

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]: Retrieved
            document texts, their metadata (numeric values stringified), and
            similarity scores — or None for scores if any result lacked one.
        """
        query_embedding = self.embedding.embed_query(query)
        # Each result is expected to be a dict with 'text', 'metadata', and
        # optionally 'score' keys — TODO confirm against BaseVectorStore.
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k,
        )
        documents = [doc['text'] for doc in results]
        sources = [
            self._convert_metadata_to_strings(doc['metadata']) for doc in results
        ]
        scores = [doc['score'] for doc in results if doc.get('score') is not None]
        # A partial score list would be misaligned with `documents`, so only
        # report scores when every result carried one.
        if len(scores) != len(documents):
            scores = None
        return documents, sources, scores

    async def generate_response(
        self,
        query: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None,
    ) -> "RAGResponse":
        """Generate a response using the RAG approach.

        Args:
            query (str): User input query.
            temperature (float): Sampling temperature for the LLM.
            max_tokens (Optional[int]): Maximum tokens to generate.
            context_docs (Optional[List[str]]): Optional pre-provided context
                documents; when falsy (None or empty), context is retrieved
                from the vector store instead.

        Returns:
            RAGResponse: Response with generated text and context.
        """
        if not context_docs:
            context_docs, sources, scores = self.retrieve_context(query)
        else:
            # Caller supplied its own context; we have no provenance for it.
            sources = None
            scores = None

        augmented_prompt = self._construct_prompt(query, context_docs)
        # NOTE(review): self.llm.generate is called synchronously inside an
        # async method; if BaseLLM exposes a coroutine API this should be
        # awaited — confirm against BaseLLM.
        response = self.llm.generate(
            augmented_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return RAGResponse(
            response=response,
            context_docs=context_docs,
            sources=sources,
            scores=scores,
        )

    def _construct_prompt(self, query: str, context_docs: List[str]) -> str:
        """Construct an augmented LLM prompt from the retrieved context.

        Args:
            query (str): Original user query.
            context_docs (List[str]): Retrieved context documents.

        Returns:
            str: Prompt containing the context block followed by the query.
        """
        context_str = "\n\n".join(context_docs)
        # Built from explicit lines (rather than an indented triple-quoted
        # literal) so no source-code indentation leaks into the prompt text.
        return (
            "Context Information:\n"
            f"{context_str}\n\n"
            f"User Query: {query}\n\n"
            "Based on the context, please provide a comprehensive and "
            "accurate response.\n"
        )