# Source: SkazuHD's Space — "init space" (commit d660b02)
import json
from typing import Any, Dict, Optional
from loguru import logger
from threading import Lock
try:
import boto3
except ModuleNotFoundError:
logger.warning("Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS.")
from langchain_ollama import ChatOllama
from llm_engineering.domain.inference import Inference
from llm_engineering.settings import settings
from langchain.schema import AIMessage, HumanMessage, SystemMessage
class LLMInferenceSagemakerEndpoint(Inference):
    """
    Performs LLM inference through an AWS SageMaker runtime endpoint.

    Builds a JSON payload (prompt text plus generation parameters) and sends
    it to the configured endpoint via the boto3 ``sagemaker-runtime`` client.
    """

    def __init__(
        self,
        endpoint_name: str,
        default_payload: Optional[Dict[str, Any]] = None,
        inference_component_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            endpoint_name (str): Name of the SageMaker endpoint to invoke.
            default_payload (dict, optional): Initial request payload; falls
                back to ``_default_payload()`` when not provided.
            inference_component_name (str, optional): SageMaker inference
                component targeted by the invocation. Defaults to None.
        """
        super().__init__()

        # Credentials and region come from the project-wide settings object.
        self.client = boto3.client(
            "sagemaker-runtime",
            region_name=settings.AWS_REGION,
            aws_access_key_id=settings.AWS_ACCESS_KEY,
            aws_secret_access_key=settings.AWS_SECRET_KEY,
        )
        self.endpoint_name = endpoint_name
        self.payload = default_payload if default_payload else self._default_payload()
        self.inference_component_name = inference_component_name

    def _default_payload(self) -> Dict[str, Any]:
        """
        Generates the default payload for the inference request.

        Returns:
            dict: Placeholder prompt plus generation parameters taken from
            the project settings.
        """
        return {
            "inputs": "How is the weather?",
            "parameters": {
                "max_new_tokens": settings.MAX_NEW_TOKENS_INFERENCE,
                "top_p": settings.TOP_P_INFERENCE,
                "temperature": settings.TEMPERATURE_INFERENCE,
                "return_full_text": False,
            },
        }

    def set_payload(self, inputs: str, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the payload for the inference request.

        Args:
            inputs (str): The input text for the inference.
            parameters (dict, optional): Extra generation parameters merged
                into the existing ones. Defaults to None.
        """
        # Fix: removed leftover debug print statements that polluted stdout.
        self.payload["inputs"] = inputs
        if parameters:
            self.payload["parameters"].update(parameters)

    def inference(self) -> Dict[str, Any]:
        """
        Performs the inference request using the SageMaker endpoint.

        Returns:
            dict: The decoded JSON response body from the endpoint.

        Raises:
            Exception: Re-raised (after logging) if the invocation fails.
        """
        try:
            logger.info("Inference request sent.")
            invoke_args = {
                "EndpointName": self.endpoint_name,
                "ContentType": "application/json",
                "Body": json.dumps(self.payload),
            }
            # The component name may arrive as the literal string "None"
            # (e.g. from env/config parsing), so both forms count as unset.
            if self.inference_component_name not in ["None", None]:
                invoke_args["InferenceComponentName"] = self.inference_component_name
            response = self.client.invoke_endpoint(**invoke_args)
            response_body = response["Body"].read().decode("utf8")

            return json.loads(response_body)
        except Exception:
            logger.exception("SageMaker inference failed.")
            raise
class LLMInferenceOLLAMA(Inference):
    """
    Performs LLM inference against a local Ollama model via
    ``langchain_ollama.ChatOllama``.

    Implements a thread-safe Singleton: every construction returns the same
    instance, and the ``model_name`` of the first call wins (later values
    are ignored).
    """

    _instance = None
    _lock = Lock()  # Guards first-time instance creation across threads.

    def __new__(cls, model_name: str):
        # Double-checked locking: cheap unlocked test first, then a locked
        # re-check so only one thread ever creates the instance.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    # Fix: use the project logger instead of debug print().
                    logger.debug("Creating new {} instance.", cls.__name__)
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_name: str) -> None:
        """
        Args:
            model_name (str): Ollama model identifier; only honored the
                first time the singleton is created.
        """
        # Guard so the singleton's state is initialized exactly once:
        # __init__ runs on every construction, even for the cached instance.
        if not hasattr(self, "_initialized"):
            super().__init__()
            self.payload = []
            self.llm = ChatOllama(
                model=model_name,
                temperature=0.7,
            )
            self._initialized = True  # Flag to prevent reinitialization

    def set_payload(self, query: str, context: str | None, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the message payload for the next inference call.

        Args:
            query (str): The user's question.
            context (str | None): Retrieved context from the external
                database; skipped entirely when None.
            parameters (dict, optional): Currently unused; kept for
                interface symmetry with the SageMaker implementation.
        """
        messages = [
            SystemMessage(content='You are a helpful Assistant that answers questions of the user accurately given its knowledge and the provided context that was found in the external database'),
        ]
        # Fix: guard against a missing context — SystemMessage(content=None)
        # fails message validation.
        if context is not None:
            messages.append(SystemMessage(content=context))
        # Wrap the query explicitly rather than relying on LangChain's
        # bare-string-to-human-message coercion.
        messages.append(HumanMessage(content=query))
        self.payload = messages

    def inference(self) -> AIMessage:
        """
        Runs the chat model on the previously set payload.

        Returns:
            AIMessage: The model's response message (``ChatOllama.invoke``
            returns a chat message, not a dict).

        Raises:
            Exception: Propagated from the underlying LangChain/Ollama call.
        """
        # Fix: debug print() replaced with a debug-level log record.
        logger.debug("Ollama payload: {}", self.payload)
        return self.llm.invoke(self.payload)