File size: 2,978 Bytes
640b1c8
 
 
 
 
e87abff
 
 
640b1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# src/agents/rag_agent.py
from dataclasses import dataclass
from typing import List, Optional

from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text

@dataclass
class RAGResponse:
    """Result of a RAG generation call.

    Bundles the text generated by the LLM together with the context
    documents that were supplied (or retrieved) to produce it.
    """
    response: str  # text returned by the LLM for the augmented prompt
    context_docs: Optional[List[str]] = None  # context used for generation; None when not attached

class RAGAgent:
    """Retrieval-augmented generation (RAG) agent.

    Wires together three collaborators: an embedding model that
    vectorizes queries, a vector store that finds similar documents,
    and an LLM that answers from a context-augmented prompt.
    """

    def __init__(
        self, 
        llm: BaseLLM, 
        embedding: BaseEmbedding, 
        vector_store: BaseVectorStore
    ):
        """Store the LLM, embedding model, and vector store collaborators."""
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def retrieve_context(
        self, 
        query: str, 
        top_k: int = 3
    ) -> List[str]:
        """Fetch the documents most similar to ``query``.

        Args:
            query (str): Query text to search for.
            top_k (int): How many documents to return.

        Returns:
            List[str]: Matching documents from the vector store.
        """
        # Vectorize the query, then delegate the nearest-neighbour
        # lookup to the configured vector store.
        embedded_query = self.embedding.embed_query(query)
        return self.vector_store.similarity_search(
            embedded_query, 
            top_k=top_k
        )

    def generate_response(
        self, 
        query: str, 
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Answer ``query`` with retrieval-augmented generation.

        Args:
            query (str): User question.
            context_docs (Optional[List[str]]): Documents to ground the
                answer in; when falsy, they are retrieved automatically.

        Returns:
            RAGResponse: Generated text plus the context that was used.
        """
        # Falsy context (None or an empty list) means: retrieve it here.
        docs = context_docs or self.retrieve_context(query)

        # Prepend the context to the query and let the LLM answer.
        prompt = self._construct_prompt(query, docs)
        return RAGResponse(
            response=self.llm.generate(prompt), 
            context_docs=docs
        )

    def _construct_prompt(
        self, 
        query: str, 
        context_docs: List[str]
    ) -> str:
        """Build the context-augmented prompt sent to the LLM.

        Args:
            query (str): Original user query.
            context_docs (List[str]): Documents to embed in the prompt.

        Returns:
            str: Prompt text combining the context and the query.
        """
        # Separate documents with blank lines inside the prompt body.
        context_str = "\n\n".join(context_docs)
        
        return f"""
        Context Information:
        {context_str}
        
        User Query: {query}
        
        Based on the context, please provide a comprehensive and accurate response.
        """