Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from typing import Any, Dict, List, Optional, Tuple, TypedDict | |
from litellm.types.llms.openai import AllMessageValues | |
from litellm.types.utils import StandardCallbackDynamicParams | |
# Result of compiling a managed prompt, as produced by a prompt-management
# integration. Declared with the functional TypedDict form; all keys are
# required (total=True), matching the class-based declaration.
#   prompt_id                       - identifier the template was fetched by
#   prompt_template                 - template messages
#   prompt_template_model           - model pinned on the template, if any
#   prompt_template_optional_params - extra request params from the template, if any
#   completed_messages              - template + client messages once merged, if any
PromptManagementClient = TypedDict(
    "PromptManagementClient",
    {
        "prompt_id": str,
        "prompt_template": List[AllMessageValues],
        "prompt_template_model": Optional[str],
        "prompt_template_optional_params": Optional[Dict[str, Any]],
        "completed_messages": Optional[List[AllMessageValues]],
    },
)
class PromptManagementBase(ABC):
    """Abstract base for integrations that turn a ``prompt_id`` into messages.

    Subclasses provide the integration name, decide whether prompt
    management applies to a request, and compile the stored template.
    This base class merges the compiled template with the caller's
    messages and folds the template's model and optional params into the
    outgoing request.
    """

    @property
    @abstractmethod
    def integration_name(self) -> str:
        """Name of the integration.

        Read as an attribute by ``_get_model_from_prompt``, which strips a
        leading ``"<integration_name>/"`` from the model string.
        """

    @abstractmethod
    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Return True if this integration should handle ``prompt_id``.

        When False, ``get_chat_completion_prompt`` returns the request
        unchanged.
        """

    @abstractmethod
    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Fetch and render the template for ``prompt_id``.

        The returned dict's ``completed_messages`` entry is overwritten by
        ``compile_prompt``.
        """

    def merge_messages(
        self,
        prompt_template: List[AllMessageValues],
        client_messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """Combine template and client messages, template messages first.

        Returns a new list; neither input is mutated. Subclasses may
        override this to customise merging.
        """
        return prompt_template + client_messages

    def compile_prompt(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        client_messages: List[AllMessageValues],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Compile the template for ``prompt_id`` and merge in client messages.

        Args:
            prompt_id: Identifier of the managed prompt.
            prompt_variables: Variables passed through to the helper.
            client_messages: Messages supplied by the caller.
            dynamic_callback_params: Per-request integration params.

        Returns:
            The helper's ``PromptManagementClient``, with
            ``completed_messages`` set to the merged message list.

        Raises:
            ValueError: If merging the template with the client messages
                fails. The original exception is chained as the cause.
        """
        compiled_prompt_client = self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
        )
        try:
            # Go through the overridable hook so subclass merge strategies apply.
            messages = self.merge_messages(
                prompt_template=compiled_prompt_client["prompt_template"],
                client_messages=client_messages,
            )
        except Exception as e:
            # NOTE(review): dynamic_callback_params may carry per-request
            # credentials; confirm it is safe to embed in this message.
            raise ValueError(
                f"Error compiling prompt: {e}. Prompt id={prompt_id}, "
                f"prompt_variables={prompt_variables}, "
                f"client_messages={client_messages}, "
                f"dynamic_callback_params={dynamic_callback_params}"
            ) from e
        compiled_prompt_client["completed_messages"] = messages
        return compiled_prompt_client

    def _get_model_from_prompt(
        self, prompt_management_client: PromptManagementClient, model: str
    ) -> str:
        """Choose the model for the request.

        Prefer the model pinned on the template. Otherwise return ``model``
        with a leading ``"<integration_name>/"`` routing prefix removed.
        """
        template_model = prompt_management_client["prompt_template_model"]
        if template_model is not None:
            return template_model
        prefix = "{}/".format(self.integration_name)
        # Strip only a leading prefix: str.replace would also delete the
        # substring from the middle of unrelated model names.
        if model.startswith(prefix):
            return model[len(prefix):]
        return model

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Apply the managed prompt to a chat-completion request.

        Args:
            model: Requested model string.
            messages: Caller-supplied messages.
            non_default_params: Caller-supplied request params.
            prompt_id: Identifier of the managed prompt; required.
            prompt_variables: Variables for template rendering.
            dynamic_callback_params: Per-request integration params.

        Returns:
            ``(model, messages, params)``. These are unchanged when
            ``should_run_prompt_management`` returns False. Otherwise they
            reflect the compiled template, and template optional params
            override caller params on key collisions.

        Raises:
            ValueError: If ``prompt_id`` is None, or if compiling fails.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Prompt Management Base class")
        if not self.should_run_prompt_management(
            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
        ):
            return model, messages, non_default_params

        compiled_prompt = self.compile_prompt(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            client_messages=messages,
            dynamic_callback_params=dynamic_callback_params,
        )
        # Fall back to the caller's messages if nothing was produced.
        completed_messages = compiled_prompt["completed_messages"] or messages
        template_params = compiled_prompt["prompt_template_optional_params"] or {}
        # Later keys win: template params override caller params.
        updated_non_default_params = {**non_default_params, **template_params}
        resolved_model = self._get_model_from_prompt(
            prompt_management_client=compiled_prompt, model=model
        )
        return resolved_model, completed_messages, updated_non_default_params