# mlk8s / kron / llm_predictor / KronBasetenCamelLLM.py
# Author: Arylwen
# Version: v0.0.6 sidebar
# Commit: 82caff6
import logging
from typing import Any, Optional, List

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Baseten
logger = logging.getLogger(__name__)


class KronBasetenCamelLLM(Baseten):
    """Baseten-hosted LLM whose endpoint takes an ``instruction`` field.

    Overrides ``_call`` so that the prompt is sent under the
    ``"instruction"`` key of the request payload and the generated text is
    read from the ``"completion"`` key of the endpoint's response.
    """

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to Baseten deployed model endpoint.

        Args:
            prompt: Text sent to the model as the ``"instruction"`` field.
            stop: Optional stop sequences. They are not forwarded to the
                endpoint; instead the returned completion is truncated just
                before the earliest occurrence of any non-empty stop sequence.
            run_manager: Accepted for LangChain interface compatibility;
                not used here.
            **kwargs: Extra fields merged into the request payload after
                ``"instruction"`` (so a kwarg named ``instruction`` would
                override ``prompt``).

        Returns:
            The ``"completion"`` value from the endpoint response, truncated
            at the first stop sequence when ``stop`` is given.

        Raises:
            ImportError: If the ``baseten`` package is not installed.
            KeyError: If the response has no ``"completion"`` key (the
                response is assumed to be a mapping — TODO confirm against
                the deployed model's output schema).
        """
        try:
            import baseten
        except ImportError as exc:
            raise ImportError(
                "Could not import Baseten Python package. "
                "Please install it with `pip install baseten`."
            ) from exc

        payload = {"instruction": prompt, **kwargs}
        # ``self.model`` is tried first as a model *version* ID. Any ApiError
        # raised during that attempt (lookup or prediction) triggers a retry
        # that treats ``self.model`` as a model ID instead.
        try:
            model = baseten.deployed_model_version_id(self.model)
            response = model.predict(payload)
        except baseten.common.core.ApiError:
            model = baseten.deployed_model_id(self.model)
            response = model.predict(payload)

        # Debug-level logging rather than print(): library code should not
        # write full model responses to the caller's stdout.
        logger.debug("baseten response: %s", response)
        completion = response["completion"]
        return self._truncate_at_stop(completion, stop)

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Return ``text`` cut just before the earliest stop sequence found.

        Empty stop strings are ignored (``str.find("")`` would otherwise
        match at index 0 and discard the whole text). If ``stop`` is empty
        or ``None``, or no sequence occurs, ``text`` is returned unchanged.
        """
        if not stop:
            return text
        hits = [idx for idx in (text.find(seq) for seq in stop if seq) if idx != -1]
        return text[: min(hits)] if hits else text