File size: 1,253 Bytes
c0cd1dc
 
 
a553e02
 
 
c0cd1dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a553e02
c0cd1dc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

from typing import Any, Generator, Optional, Protocol, Tuple, runtime_checkable

from llama_index.core.service_context_elements.llm_predictor import  LLMPredictor
from llama_index.core.llms.utils import LLMType
from llama_index.core.callbacks.base import CallbackManager

from kron.llm_predictor.utils import kron_resolve_llm

class KronLLMPredictor(LLMPredictor):
    """LLM predictor that resolves its LLM through kron's resolver.

    Subclass of ``LLMPredictor`` whose only difference from the parent is
    that the ``llm`` argument is resolved with ``kron_resolve_llm`` instead
    of llama-index's default resolver.

    Args:
        llm (Optional[LLMType]): LLM specifier passed to
            ``kron_resolve_llm`` to obtain the concrete LLM instance.
            Defaults to ``None`` (the resolver's own default is used —
            TODO confirm against ``kron_resolve_llm``).
        callback_manager (Optional[CallbackManager]): Callback manager
            attached to the resolved LLM. When ``None``, an empty
            ``CallbackManager([])`` is attached instead.
    """

    def __init__(
        self,
        llm: Optional[LLMType] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialize params."""
        # Resolve via kron's resolver rather than the parent class default.
        self._llm = kron_resolve_llm(llm)
        # Guarantee the LLM always carries a callback manager, even if empty.
        self._llm.callback_manager = callback_manager or CallbackManager([])