|
import logging

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Baseten
|
|
|
class KronBasetenCamelLLM(Baseten):
    """Baseten-hosted LLM that sends the prompt under an ``instruction`` key.

    Differs from the stock :class:`~langchain.llms.Baseten` in the payload
    shape: the deployed (Camel-style) model expects
    ``{"instruction": <prompt>, ...}`` and returns a dict with a
    ``"completion"`` field.
    """

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the Baseten-deployed model endpoint and return its completion.

        Args:
            prompt: Text placed in the ``instruction`` field of the payload.
            stop: Accepted for LangChain interface compatibility; currently
                ignored (not forwarded to the endpoint).
            run_manager: Accepted for interface compatibility; unused here.
            **kwargs: Extra fields merged into the request payload.

        Returns:
            The ``"completion"`` string from the endpoint response.

        Raises:
            ImportError: If the ``baseten`` package is not installed.
            KeyError: If the response has no ``"completion"`` field.
        """
        try:
            import baseten
        except ImportError as exc:
            raise ImportError(
                "Could not import Baseten Python package. "
                "Please install it with `pip install baseten`."
            ) from exc

        payload = {"instruction": prompt, **kwargs}

        # `self.model` may hold either a model-version id or a model id.
        # Try the version lookup first; on an ApiError fall back to
        # resolving it as a plain model id.
        try:
            model = baseten.deployed_model_version_id(self.model)
            response = model.predict(payload)
        except baseten.common.core.ApiError:
            model = baseten.deployed_model_id(self.model)
            response = model.predict(payload)

        # Debug-level log (was a bare print) so library users control
        # verbosity via standard logging configuration.
        logging.getLogger(__name__).debug("baseten response: %s", response)

        return response['completion']