File size: 1,205 Bytes
c0cd1dc 82caff6 c0cd1dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import logging
from typing import Any, Optional, List

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Baseten
class KronBasetenCamelLLM(Baseten):
    """Baseten-backed LangChain LLM whose endpoint takes an ``instruction`` payload.

    Differs from the stock ``Baseten`` LLM only in the request/response shape
    used by ``_call``: the prompt is sent under the ``"instruction"`` key and
    the generated text is read from the ``"completion"`` key of the response.
    """

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to Baseten deployed model endpoint.

        Args:
            prompt: Text sent to the model as the ``"instruction"`` field.
            stop: Optional stop sequences. They are not forwarded to the
                endpoint; instead the returned text is cut at the earliest
                occurrence of any of them, as LangChain callers of ``stop``
                expect.
            run_manager: Accepted for interface compatibility; unused here.
            **kwargs: Extra fields merged into the request payload.

        Returns:
            The ``"completion"`` value of the model response, truncated at
            the first stop sequence when ``stop`` is given.

        Raises:
            ImportError: If the ``baseten`` package is not installed.
        """
        try:
            import baseten
        except ImportError as exc:
            raise ImportError(
                "Could not import Baseten Python package. "
                "Please install it with `pip install baseten`."
            ) from exc

        # self.model is first treated as a model-version id; if Baseten
        # rejects that with an ApiError, retry treating it as a model id.
        try:
            model = baseten.deployed_model_version_id(self.model)
            response = model.predict({"instruction": prompt, **kwargs})
        except baseten.common.core.ApiError:
            model = baseten.deployed_model_id(self.model)
            response = model.predict({"instruction": prompt, **kwargs})

        # Debug-level log instead of print(): library code should not write
        # full model responses to stdout on every call.
        logging.getLogger(__name__).debug("baseten response: %s", response)
        # NOTE(review): assumes the endpoint returns a mapping with a
        # "completion" key; a KeyError here means the response shape differs.
        return self._truncate_at_stop(response["completion"], stop)

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Return *text* cut at the earliest literal occurrence of any stop sequence."""
        if not stop:
            return text
        cut = len(text)
        for token in stop:
            if token:  # an empty stop string would otherwise truncate to ""
                index = text.find(token)
                if index != -1:
                    cut = min(cut, index)
        return text[:cut]