|
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence |
|
|
|
from llama_index.core.bridge.langchain import BaseLanguageModel, BaseChatModel |
|
|
|
from llama_index.llms.openai import OpenAI |
|
|
|
from llama_index.core.base.llms.types import ( |
|
|
|
ChatMessage, |
|
ChatResponse, |
|
ChatResponseAsyncGen, |
|
ChatResponseGen, |
|
CompletionResponse, |
|
CompletionResponseAsyncGen, |
|
CompletionResponseGen, |
|
LLMMetadata, |
|
) |
|
|
|
from llama_index.llms.openai.utils import ( |
|
from_openai_message, |
|
is_chat_model, |
|
is_function_calling_model, |
|
openai_modelname_to_contextsize, |
|
resolve_openai_credentials, |
|
to_openai_message_dicts, |
|
) |
|
|
|
from kron.llm_predictor.openai_utils import kron_openai_modelname_to_contextsize |
|
|
|
class KronOpenAI(OpenAI):
    """OpenAI LLM wrapper with kron-aware context-size resolution.

    Overrides ``metadata`` so the context window is looked up through
    ``kron_openai_modelname_to_contextsize`` (covering models the stock
    llama_index table may not know), and overrides ``complete`` to clean
    the returned text by truncating at the first ``<|endoftext|>`` marker.
    """

    @property
    def metadata(self) -> LLMMetadata:
        """Return :class:`LLMMetadata` for ``self.model``.

        Uses the kron context-size table instead of the stock
        ``openai_modelname_to_contextsize`` lookup.
        """
        # Hoist the model-name lookup: the original called
        # _get_model_name() once per capability flag.
        model_name = self._get_model_name()
        return LLMMetadata(
            context_window=kron_openai_modelname_to_contextsize(self.model),
            # -1 signals "no explicit output limit" when max_tokens is unset.
            num_output=self.max_tokens or -1,
            is_chat_model=is_chat_model(model=model_name),
            is_function_calling_model=is_function_calling_model(model=model_name),
            model_name=self.model,
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Complete ``prompt`` and return the response with cleaned text.

        The completion text is truncated at the first ``<|endoftext|>``
        marker and stripped of surrounding whitespace; ``response.text``
        is mutated in place.
        """
        response = super().complete(prompt, **kwargs)
        # Truncate at the sentinel FIRST, then strip: the original
        # stripped before splitting, which left any whitespace that sat
        # immediately before "<|endoftext|>" dangling at the end of the
        # returned text.
        response.text = response.text.split("<|endoftext|>")[0].strip()
        return response