Spaces:
Build error
Build error
File size: 3,166 Bytes
01523b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import re
from string import Template
from typing import List
from pydantic import Field, validator
from agentverse.initialization import load_llm
from agentverse.llms.base import BaseLLM
from agentverse.message import Message
from . import memory_registry
from .base import BaseMemory
@memory_registry.register("summary")
class SummaryMemory(BaseMemory):
    """Memory that stores an LLM-written summary of the messages it receives.

    Every batch of messages given to ``add_message`` is joined into text,
    substituted into ``prompt_template`` and sent to ``llm``. How the
    response is stored in ``buffer`` depends on ``recursive``:

    - ``recursive=True``: the template receives the current buffer as
      ``$summary`` and the response *replaces* the buffer.
    - ``recursive=False``: the template only receives ``$new_lines`` and the
      response is *appended* to the buffer on a new line.
    """

    llm: BaseLLM
    # Cleared by reset(); this class itself never appends to it.
    messages: List[Message] = Field(default_factory=list)
    # The summary text returned by to_string().
    buffer: str = Field(default="")
    # NOTE: `recursive` must stay declared before `prompt_template`; the
    # validator below reads it from `values`, which only holds fields that
    # were validated earlier (i.e. declared earlier).
    recursive: bool = Field(default=False)
    prompt_template: str = Field(default="")

    def __init__(self, *args, **kwargs):
        """Create the memory, building the LLM from its configuration.

        The ``llm`` keyword argument is required: it is removed from
        ``kwargs``, passed to ``load_llm`` and the resulting object is
        forwarded to the parent constructor as ``llm``.

        Raises:
            KeyError: If no ``llm`` keyword argument is supplied.
        """
        llm_config = kwargs.pop("llm")
        llm = load_llm(llm_config)
        super().__init__(*args, llm=llm, **kwargs)

    @validator("prompt_template")
    def check_prompt_template(cls, v, values):
        """Check if the prompt template is valid.

        When recursive is True, the prompt template should contain the
        following arguments:
        - $summary: The summary so far.
        - $new_lines: The new lines to be added to the summary.
        Otherwise, the prompt template should only contain $new_lines.

        Both the ``$name`` and ``${name}`` placeholder spellings are accepted.

        Raises:
            ValueError: If the template's placeholders do not match the
                ``recursive`` setting.
        """
        recursive = values.get("recursive")
        summary_pat = re.compile(r"\$\{?summary\}?")
        new_lines_pat = re.compile(r"\$\{?new_lines\}?")
        if recursive:
            if not summary_pat.search(v):
                raise ValueError(
                    "When recursive is True, the prompt template should contain $summary."
                )
            if not new_lines_pat.search(v):
                raise ValueError(
                    "When recursive is True, the prompt template should contain $new_lines."
                )
        else:
            if summary_pat.search(v):
                raise ValueError(
                    "When recursive is False, the prompt template should not contain $summary."
                )
            if not new_lines_pat.search(v):
                raise ValueError(
                    "When recursive is False, the prompt template should contain $new_lines."
                )
        return v

    def add_message(self, messages: List[Message]) -> None:
        """Summarize ``messages`` and fold the result into the buffer.

        The contents of the messages are joined with newlines and handed to
        ``update_buffer``. An empty list is ignored so that no LLM request is
        made (and, in recursive mode, the existing summary is not rewritten)
        when there is nothing new to summarize.
        """
        if not messages:
            return
        new_lines = "\n".join(message.content for message in messages)
        self.update_buffer(new_lines)

    def update_buffer(self, new_message: str):
        """Ask the LLM to summarize ``new_message`` and store the response.

        In recursive mode the prompt already contains the previous summary,
        so the response replaces the buffer. In non-recursive mode each
        response only covers the new lines, so it is appended to the buffer
        (separated by a newline) instead of overwriting earlier summaries.
        """
        prompt = self._fill_in_prompt_template(new_message)
        response = self.llm.generate_response(prompt)
        if self.recursive:
            self.buffer = response.content
        else:
            # Accumulate: plain assignment here would drop every earlier
            # summary and keep only the most recent one.
            self.buffer += "\n" + response.content

    def _fill_in_prompt_template(self, new_lines: str) -> str:
        """Fill in the prompt template with the given arguments.

        SummaryMemory supports the following arguments:
        - summary: The summary so far.
        - new_lines: The new lines to be added to the summary.

        ``safe_substitute`` is used, so any other placeholder in the template
        is left in place rather than raising.
        """
        input_arguments = {"summary": self.buffer, "new_lines": new_lines}
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def to_string(self, *args, **kwargs) -> str:
        """Return the current summary buffer; extra arguments are ignored."""
        return self.buffer

    def reset(self) -> None:
        """Clear the stored messages and the summary buffer."""
        self.messages = []
        self.buffer = ""
|