|
import os |
|
import typing |
|
import json |
|
from langchain_community.llms import SagemakerEndpoint |
|
from langchain.llms.sagemaker_endpoint import LLMContentHandler |
|
from pydantic.v1 import root_validator |
|
|
|
from utils import FakeTokenizer |
|
|
|
|
|
class ChatContentHandler(LLMContentHandler):
    """Content handler for SageMaker endpoints that take a chat message list.

    Requests are serialized as
    ``{"inputs": [[<messages>]], "parameters": model_kwargs}`` where
    ``<messages>`` is an optional system message followed by the user prompt.
    The reply text is read from ``response[0]["generation"]["content"]``.
    """

    content_type = "application/json"
    accepts = "application/json"

    # System message prepended to every request. Override on a subclass or
    # instance; an empty string sends only the user message.
    system_prompt: str = "You are a helpful assistant."

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Encode ``prompt`` as a chat request body.

        Args:
            prompt: The user message text.
            model_kwargs: Generation parameters, forwarded verbatim under
                the ``"parameters"`` key.

        Returns:
            The UTF-8 encoded JSON payload.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})
        # The conversation is wrapped in an outer list; presumably the
        # endpoint accepts a batch of conversations (confirm against the
        # model's input schema).
        input_dict = {"inputs": [messages], "parameters": model_kwargs}
        return json.dumps(input_dict).encode("utf-8")

    def transform_output(
        self, output: typing.Union[bytes, typing.IO[bytes]]
    ) -> str:
        """Extract the generated message text from an endpoint response.

        Args:
            output: Either the raw response bytes or a readable binary stream
                (e.g. a boto3 response body) yielding JSON.

        Returns:
            The ``content`` of the first generation.
        """
        # Accept a stream (as the original ``.read()`` call required) as well
        # as the plain bytes promised by the type annotation.
        raw = output.read() if hasattr(output, "read") else output
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8")
        response_json = json.loads(raw)
        return response_json[0]["generation"]["content"]
|
|
|
|
|
class BaseContentHandler(LLMContentHandler):
    """Content handler for SageMaker text-completion endpoints.

    Requests are serialized as ``{"inputs": prompt, "parameters": model_kwargs}``;
    the completion text is read from ``response[0]["generation"]``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Encode ``prompt`` and ``model_kwargs`` as a UTF-8 JSON request body.

        Args:
            prompt: The raw prompt text.
            model_kwargs: Generation parameters, forwarded verbatim under
                the ``"parameters"`` key.

        Returns:
            The encoded JSON payload.
        """
        input_dict = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(input_dict).encode("utf-8")

    def transform_output(
        self, output: typing.Union[bytes, typing.IO[bytes]]
    ) -> str:
        """Extract the generated text from an endpoint response.

        Args:
            output: Either the raw response bytes or a readable binary stream
                (e.g. a boto3 response body) yielding JSON.

        Returns:
            The ``generation`` value of the first result.
        """
        # Accept a stream (as the original ``.read()`` call required) as well
        # as the plain bytes promised by the type annotation.
        raw = output.read() if hasattr(output, "read") else output
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8")
        response_json = json.loads(raw)
        return response_json[0]["generation"]
|
|
|
|
|
class H2OSagemakerEndpoint(SagemakerEndpoint):
    """``SagemakerEndpoint`` that accepts explicit AWS keys and a custom tokenizer.

    Attributes:
        aws_access_key_id: Access key for the ``sagemaker-runtime`` client.
            Leave empty to use the named profile / default credential chain.
        aws_secret_access_key: Secret key paired with ``aws_access_key_id``.
        tokenizer: Optional object with an ``encode(text)`` method, used by
            ``get_token_ids``. When ``None``, a ``FakeTokenizer`` is used.
    """

    aws_access_key_id: str = ""
    aws_secret_access_key: str = ""
    tokenizer: typing.Any = None

    # skip_on_failure: don't run on partially-validated ``values`` if field
    # validation already failed.
    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: typing.Dict) -> typing.Dict:
        """Create the ``sagemaker-runtime`` boto3 client and store it in ``values["client"]``.

        Uses ``credentials_profile_name`` when set, and explicit AWS keys
        when non-empty.

        Raises:
            ImportError: boto3 is not installed.
            ValueError: the boto3 session or client could not be created.
        """
        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                session = boto3.Session()

            # botocore treats any non-None key pair -- including the "" field
            # defaults -- as explicit credentials and skips the profile /
            # default credential chain, so map empty values to None. An empty
            # region likewise falls back to boto3's configured region.
            values["client"] = session.client(
                "sagemaker-runtime",
                region_name=values.get("region_name") or None,
                aws_access_key_id=values.get("aws_access_key_id") or None,
                aws_secret_access_key=values.get("aws_secret_access_key") or None,
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def get_token_ids(self, text: str) -> typing.List[int]:
        """Return the token ids of ``text``.

        Uses ``self.tokenizer.encode`` when a tokenizer is configured.
        Otherwise it uses the ``"input_ids"`` produced by ``FakeTokenizer``.
        """
        if self.tokenizer is not None:
            return self.tokenizer.encode(text)
        return FakeTokenizer().encode(text)["input_ids"]
|
|
|
|