from typing import Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
    """Shared helpers for folding conversation history into a prompt within the model's token budget."""

    def _append_chat_histories(self, memory: TokenBufferMemory,
                               memory_config: MemoryConfig,
                               prompt_messages: list[PromptMessage],
                               model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
        """Append as many history messages as the remaining token budget allows."""
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages

    def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
                              model_config: ModelConfigWithCredentialsEntity) -> int:
        """Calculate how many tokens remain for chat history once the current
        prompt and the reserved completion (max_tokens) are accounted for."""
        # Fallback budget used when the model schema declares no context size.
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            curr_message_tokens = model_type_instance.get_num_tokens(
                model_config.model,
                model_config.credentials,
                prompt_messages
            )

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    max_tokens = (model_config.parameters.get(parameter_rule.name)
                                  or model_config.parameters.get(parameter_rule.use_template)) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens
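
    # Illustrative arithmetic (hypothetical numbers, not from this codebase):
    # with an 8,192-token context window, max_tokens set to 1,024, and a current
    # prompt of 700 tokens, the history budget is 8,192 - 1,024 - 700 = 6,468.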

    def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
                                          memory_config: MemoryConfig,
                                          max_token_limit: int,
                                          human_prefix: Optional[str] = None,
                                          ai_prefix: Optional[str] = None) -> str:
        """Get memory messages as a single prompt text block."""
        kwargs = {
            "max_token_limit": max_token_limit
        }

        if human_prefix:
            kwargs['human_prefix'] = human_prefix

        if ai_prefix:
            kwargs['ai_prefix'] = ai_prefix

        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
            kwargs['message_limit'] = memory_config.window.size

        return memory.get_history_prompt_text(
            **kwargs
        )

    def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
                                               memory_config: MemoryConfig,
                                               max_token_limit: int) -> list[PromptMessage]:
        """Get memory messages as a list of prompt messages."""
        if (memory_config.window.enabled
                and memory_config.window.size is not None
                and memory_config.window.size > 0):
            message_limit = memory_config.window.size
        else:
            message_limit = 10

        return memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit
        )
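

# A minimal, self-contained sketch (the helper name and values are hypothetical
# and not part of this module): how the sliding-window setting translates into
# the message_limit passed to TokenBufferMemory.get_history_prompt_messages above.
def _resolve_message_limit_sketch(window_enabled: bool, window_size: Optional[int]) -> int:
    # An enabled window with a positive size caps history by message count;
    # otherwise a default of 10 messages applies, and the token budget computed
    # by _calculate_rest_token still trims history further if needed.
    if window_enabled and window_size is not None and window_size > 0:
        return window_size
    return 10


if __name__ == "__main__":
    assert _resolve_message_limit_sketch(True, 5) == 5
    assert _resolve_message_limit_sketch(False, 5) == 10
    assert _resolve_message_limit_sketch(True, None) == 10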