File size: 1,253 Bytes
c0cd1dc a553e02 c0cd1dc a553e02 c0cd1dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import Any, Generator, Optional, Protocol, Tuple, runtime_checkable
from llama_index.core.service_context_elements.llm_predictor import LLMPredictor
from llama_index.core.llms.utils import LLMType
from llama_index.core.callbacks.base import CallbackManager
from kron.llm_predictor.utils import kron_resolve_llm
class KronLLMPredictor(LLMPredictor):
    """LLM predictor whose underlying LLM is resolved by ``kron_resolve_llm``.

    Args:
        llm (Optional[LLMType]): LLM specification forwarded as-is to
            ``kron_resolve_llm``. ``None`` (the default) is forwarded too,
            so the choice of fallback model is left to that helper.
        callback_manager (Optional[CallbackManager]): Callback manager to
            attach to the resolved LLM. When it is omitted (or otherwise
            falsy), a fresh ``CallbackManager([])`` is attached instead.
    """

    def __init__(
        self,
        llm: Optional[LLMType] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Resolve the LLM and bind a callback manager to it."""
        # NOTE(review): the parent initializer is not invoked here -- confirm
        # that skipping ``LLMPredictor.__init__`` is intentional.
        self._llm = kron_resolve_llm(llm)

        # The resolved LLM always ends up with a callback manager: the
        # caller's if one was supplied, otherwise a new empty one.
        manager = callback_manager if callback_manager else CallbackManager([])
        self._llm.callback_manager = manager
|