"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator

from litellm import acompletion, completion, get_model_info  # type: ignore[attr-defined]

from raglite._config import RAGLiteConfig
from raglite._database import Chunk
from raglite._litellm import LlamaCppPythonLLM
from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
from raglite._typing import SearchMethod

RAG_SYSTEM_PROMPT = """
You are a friendly and knowledgeable assistant that provides complete and insightful answers.
Answer the user's question using only the context below.
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.
""".strip()


def _max_contexts(
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    messages: list[dict[str, str]] | None = None,
    config: RAGLiteConfig | None = None,
) -> int:
    """Determine the maximum number of contexts for RAG."""
    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
    # to date by loading that LLM.
    config = config or RAGLiteConfig()
    if config.llm.startswith("llama-cpp-python"):
        _ = LlamaCppPythonLLM.llm(config.llm)
    # Get the model's maximum context size.
    llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
    model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
    max_tokens = model_info.get("max_tokens") or 2048
    # Reduce the maximum number of contexts so that the prompt fits within the LLM's context window.
    max_context_tokens = (
        max_tokens
        - sum(len(message["content"]) // 3 for message in messages or [])  # Previous messages.
        - len(RAG_SYSTEM_PROMPT) // 3  # System prompt.
        - len(prompt) // 3  # User prompt.
    )
    max_tokens_per_context = config.chunk_max_size // 3
    max_tokens_per_context *= 1 + len(context_neighbors or [])
    max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context)
    if max_contexts <= 0:
        error_message = "Not enough context tokens available for RAG."
        raise ValueError(error_message)
    return max_contexts
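
# A worked example of the token-budget arithmetic above (a sketch with assumed numbers, not
# defaults guaranteed by RAGLiteConfig): with max_tokens=4096, no previous messages, a
# 1,000-character system prompt, a 200-character user prompt, chunk_max_size=1440, and
# context_neighbors=(-1, 1):
#   max_context_tokens = 4096 - 1000 // 3 - 200 // 3 = 3697
#   max_tokens_per_context = (1440 // 3) * (1 + 2) = 1440
#   max_contexts = min(5, 3697 // 1440) = 2
# The `// 3` divisions are a rough characters-per-token heuristic.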


def _contexts(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    config: RAGLiteConfig | None = None,
) -> list[str]:
    """Retrieve contexts for RAG."""
    # Determine the maximum number of contexts.
    max_contexts = _max_contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        messages=messages,
        config=config,
    )
    # Retrieve the top chunks.
    config = config or RAGLiteConfig()
    chunks: list[str] | list[Chunk]
    if callable(search):
        # If the user has configured a reranker, we retrieve extra contexts to rerank.
        extra_contexts = 3 * max_contexts if config.reranker else 0
        # Retrieve relevant contexts.
        chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
        # Rerank the relevant contexts.
        chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
    else:
        # The user has passed a list of chunk_ids or chunks directly.
        chunks = search
    # Extend the top contexts with their neighbors and group chunks into contiguous segments.
    segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
    return segments
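
# Note that `search` may also be a precomputed list of chunk ids or `Chunk` objects, in which
# case retrieval and reranking are skipped and only neighbour expansion and segmentation are
# applied. An illustrative call (with hypothetical chunk ids):
#   segments = _contexts("What is X?", search=["chunk_id_1", "chunk_id_2"], config=config)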


def rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> Iterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    stream = completion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    for output in stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token
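
# A minimal synchronous usage sketch (illustrative; assumes documents have already been inserted
# and that the default RAGLiteConfig points at a working LLM):
#   config = RAGLiteConfig()
#   for token in rag("What does the corpus say about X?", config=config):
#       print(token, end="", flush=True)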


async def async_rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> AsyncIterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    async_stream = await acompletion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    async for output in async_stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token
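
# An asynchronous usage sketch (illustrative; same assumptions as the synchronous example above):
#   import asyncio
#
#   async def main() -> None:
#       async for token in async_rag("Summarise the indexed documents.", config=RAGLiteConfig()):
#           print(token, end="", flush=True)
#
#   asyncio.run(main())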