# NOTE(review): removed stray build-log residue ("Spaces:" / "Build error")
# that was not valid Python and broke the module at import time.
import json
from threading import Lock
from typing import Any, Dict, Optional

from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from loguru import logger

from llm_engineering.domain.inference import Inference
from llm_engineering.settings import settings

try:
    import boto3
except ModuleNotFoundError:
    logger.warning("Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS.")
class LLMInferenceSagemakerEndpoint(Inference):
    """
    Perform LLM inference against an AWS SageMaker runtime endpoint.

    The request payload is kept on the instance (``self.payload``), serialized
    to JSON for ``invoke_endpoint``, and the JSON response body is decoded and
    returned.
    """

    def __init__(
        self,
        endpoint_name: str,
        default_payload: Optional[Dict[str, Any]] = None,
        inference_component_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            endpoint_name: Name of the SageMaker endpoint to invoke.
            default_payload: Initial request payload; when omitted,
                ``_default_payload()`` supplies one built from settings.
            inference_component_name: Optional inference component name;
                omitted from the request when ``None`` or the string ``"None"``.
        """
        super().__init__()
        # NOTE(review): ``boto3`` may be undefined when the optional AWS extra
        # is not installed (see guarded import at the top of the file); this
        # constructor then raises NameError.
        self.client = boto3.client(
            "sagemaker-runtime",
            region_name=settings.AWS_REGION,
            aws_access_key_id=settings.AWS_ACCESS_KEY,
            aws_secret_access_key=settings.AWS_SECRET_KEY,
        )
        self.endpoint_name = endpoint_name
        self.payload = default_payload if default_payload else self._default_payload()
        self.inference_component_name = inference_component_name

    def _default_payload(self) -> Dict[str, Any]:
        """
        Generates the default payload for the inference request.

        Returns:
            dict: The default payload (placeholder prompt plus generation
            parameters taken from settings).
        """
        return {
            "inputs": "How is the weather?",
            "parameters": {
                "max_new_tokens": settings.MAX_NEW_TOKENS_INFERENCE,
                "top_p": settings.TOP_P_INFERENCE,
                "temperature": settings.TEMPERATURE_INFERENCE,
                "return_full_text": False,
            },
        }

    def set_payload(self, inputs: str, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the payload for the inference request.

        Args:
            inputs (str): The input text for the inference.
            parameters (dict, optional): Additional parameters merged into the
                payload's ``parameters`` section. Defaults to None.
        """
        # Fixed: removed leftover (profane) debug print statements.
        self.payload["inputs"] = inputs
        if parameters:
            self.payload["parameters"].update(parameters)

    def inference(self) -> Dict[str, Any]:
        """
        Performs the inference request using the SageMaker endpoint.

        Returns:
            dict: The parsed JSON response from the endpoint.

        Raises:
            Exception: Any failure from the SageMaker call is logged with a
                traceback and re-raised unchanged.
        """
        try:
            logger.info("Inference request sent.")
            invoke_args = {
                "EndpointName": self.endpoint_name,
                "ContentType": "application/json",
                "Body": json.dumps(self.payload),
            }
            # The component name may be absent; "None" is its stringified form
            # (presumably when sourced from config/env — TODO confirm).
            if self.inference_component_name not in ["None", None]:
                invoke_args["InferenceComponentName"] = self.inference_component_name
            response = self.client.invoke_endpoint(**invoke_args)
            response_body = response["Body"].read().decode("utf8")
            return json.loads(response_body)
        except Exception:
            logger.exception("SageMaker inference failed.")
            raise
class LLMInferenceOLLAMA(Inference):
    """
    Perform LLM inference against a local Ollama model via ``ChatOllama``.

    Implements a thread-safe Singleton: every construction call returns the
    same instance (and therefore the same underlying chat model).
    """

    _instance = None
    _lock = Lock()  # Guards first-time instance creation across threads.

    def __new__(cls, model_name: str):
        # Double-checked locking: cheap unsynchronized test first, then a
        # re-check under the lock so only one thread creates the instance.
        # Fixed: removed debug print statements.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_name: str) -> None:
        """
        Args:
            model_name: Name of the Ollama model to load.
        """
        # __init__ runs on every construction call, but the singleton must
        # keep its state — so initialize only once.
        if not hasattr(self, "_initialized"):
            super().__init__()
            self.payload = []
            self.llm = ChatOllama(
                model=model_name,
                temperature=0.7,
            )
            self._initialized = True  # Flag to prevent reinitialization

    def set_payload(self, query: str, context: str | None, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the message payload for the next inference request.

        Args:
            query (str): The user's question.
            context (str | None): Retrieved context from the external
                database; the context message is omitted when None.
            parameters (dict, optional): Currently unused; kept for interface
                compatibility. Defaults to None.
        """
        messages = [
            SystemMessage(
                content=(
                    "You are a helpful Assistant that answers questions of the user "
                    "accurately given its knowledge and the provided context that was "
                    "found in the external database"
                )
            ),
        ]
        # Fixed: the original passed context straight into SystemMessage, which
        # fails message validation when context is None despite the annotation
        # allowing it.
        if context is not None:
            messages.append(SystemMessage(content=context))
        # Fixed: wrap the query explicitly instead of relying on LangChain's
        # implicit str -> HumanMessage coercion.
        messages.append(HumanMessage(content=query))
        self.payload = messages

    def inference(self) -> Dict[str, Any]:
        """
        Run the chat model on the current payload.

        Returns:
            The result of ``ChatOllama.invoke`` on the prepared messages.

        NOTE(review): the actual return is a LangChain ``AIMessage``, not a
        dict; the annotation is kept for interface compatibility. Removed the
        debug ``print(self.payload)``.
        """
        return self.llm.invoke(self.payload)