Spaces:
Build error
Build error
File size: 5,552 Bytes
d660b02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import json
from typing import Any, Dict, Optional
from loguru import logger
from threading import Lock
try:
import boto3
except ModuleNotFoundError:
logger.warning("Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS.")
from langchain_ollama import ChatOllama
from llm_engineering.domain.inference import Inference
from llm_engineering.settings import settings
from langchain.schema import AIMessage, HumanMessage, SystemMessage
class LLMInferenceSagemakerEndpoint(Inference):
    """
    Performs LLM inference against a deployed SageMaker real-time endpoint.

    The request payload follows the HuggingFace TGI schema:
    ``{"inputs": ..., "parameters": {...}}``.
    """

    def __init__(
        self,
        endpoint_name: str,
        default_payload: Optional[Dict[str, Any]] = None,
        inference_component_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            endpoint_name (str): Name of the deployed SageMaker endpoint.
            default_payload (dict, optional): Initial request payload; falls
                back to `_default_payload()` when not provided.
            inference_component_name (str, optional): Inference component to
                target on a multi-component endpoint. Defaults to None.
        """
        super().__init__()
        self.client = boto3.client(
            "sagemaker-runtime",
            region_name=settings.AWS_REGION,
            aws_access_key_id=settings.AWS_ACCESS_KEY,
            aws_secret_access_key=settings.AWS_SECRET_KEY,
        )
        self.endpoint_name = endpoint_name
        self.payload = default_payload if default_payload else self._default_payload()
        self.inference_component_name = inference_component_name

    def _default_payload(self) -> Dict[str, Any]:
        """
        Generates the default payload for the inference request.

        Returns:
            dict: The default payload (TGI-style inputs + generation parameters).
        """
        return {
            "inputs": "How is the weather?",
            "parameters": {
                "max_new_tokens": settings.MAX_NEW_TOKENS_INFERENCE,
                "top_p": settings.TOP_P_INFERENCE,
                "temperature": settings.TEMPERATURE_INFERENCE,
                "return_full_text": False,
            },
        }

    def set_payload(self, inputs: str, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the payload for the inference request.

        Args:
            inputs (str): The input text for the inference.
            parameters (dict, optional): Additional generation parameters,
                merged into (and overriding) the existing ones. Defaults to None.
        """
        # NOTE(review): removed leftover debug print statements that were here.
        self.payload["inputs"] = inputs
        if parameters:
            self.payload["parameters"].update(parameters)

    def inference(self) -> Dict[str, Any]:
        """
        Performs the inference request using the SageMaker endpoint.

        Returns:
            dict: The parsed JSON response from the endpoint.

        Raises:
            Exception: Re-raised (after logging) if the invocation fails.
        """
        try:
            logger.info("Inference request sent.")
            invoke_args = {
                "EndpointName": self.endpoint_name,
                "ContentType": "application/json",
                "Body": json.dumps(self.payload),
            }
            # The component name may be unset (None) or serialized as the
            # literal string "None" by configuration — skip it in both cases.
            if self.inference_component_name not in ["None", None]:
                invoke_args["InferenceComponentName"] = self.inference_component_name
            response = self.client.invoke_endpoint(**invoke_args)
            response_body = response["Body"].read().decode("utf8")
            return json.loads(response_body)
        except Exception:
            logger.exception("SageMaker inference failed.")
            raise
class LLMInferenceOLLAMA(Inference):
    """
    Performs LLM inference against a local Ollama model through langchain.

    Implements a thread-safe Singleton so a single ChatOllama client is
    shared across the process.
    """

    _instance = None
    _lock = Lock()  # guards first-time instance creation

    def __new__(cls, model_name: str):
        # Double-checked locking: the unlocked check keeps the common path
        # lock-free; the re-check under the lock makes creation race-free.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model_name: str) -> None:
        # __init__ runs on every construction of the singleton; the flag
        # ensures the Ollama client is only built once.
        if not hasattr(self, "_initialized"):
            super().__init__()
            self.payload = []
            self.llm = ChatOllama(
                model=model_name,
                temperature=0.7,
            )
            self._initialized = True  # Flag to prevent reinitialization

    def set_payload(self, query: str, context: str | None, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the message payload for the next inference request.

        Args:
            query (str): The user's question.
            context (str | None): Retrieved context from the external
                database; omitted from the payload when None.
            parameters (dict, optional): Currently unused; kept for interface
                compatibility with the SageMaker client. Defaults to None.
        """
        self.payload = [
            SystemMessage(content='You are a helpful Assistant that answers questions of the user accurately given its knowledge and the provided context that was found in the external database'),
        ]
        # SystemMessage(content=None) fails message validation, so only
        # attach the context message when context is actually present.
        if context is not None:
            self.payload.append(SystemMessage(content=context))
        # Wrap the raw query explicitly rather than relying on langchain's
        # implicit str -> human-message coercion.
        self.payload.append(HumanMessage(content=query))

    def inference(self) -> Dict[str, Any]:
        """
        Invokes the Ollama model with the current message payload.

        Returns:
            The model response from `ChatOllama.invoke` (an AIMessage,
            despite the inherited Dict annotation — kept for interface
            compatibility).
        """
        return self.llm.invoke(self.payload)
|