Orami01's picture
Upload 274 files
9c48ae2
raw
history blame
3.46 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Modified from https://github.com/geekan/MetaGPT/blob/main/metagpt/memory/memory.py
from collections import defaultdict
from typing import Iterable, Type
from autoagents.actions import Action
from autoagents.system.schema import Message
class Memory:
    """The most basic memory: super-memory.

    Messages are kept in insertion order in ``storage``. Messages with a
    truthy ``cause_by`` are additionally grouped in ``index`` under their
    ``cause_by`` value, so they can be looked up by Action.

    Invariant: a key is present in ``index`` only while its bucket holds at
    least one message. Read paths therefore use ``dict.get`` rather than
    ``self.index[...]``, which on a ``defaultdict`` would insert an empty
    bucket as a side effect.
    """

    def __init__(self):
        """Initialize an empty storage list and an empty index dictionary."""
        self.storage: list[Message] = []
        self.index: dict[Type[Action], list[Message]] = defaultdict(list)

    def add(self, message: Message):
        """Add a new message to storage, while updating the index.

        A message already in storage (compared with ``==``) is ignored.
        Messages whose ``cause_by`` is falsy are stored but not indexed.
        """
        if message in self.storage:
            return
        self.storage.append(message)
        if message.cause_by:
            self.index[message.cause_by].append(message)

    def add_batch(self, messages: Iterable[Message]):
        """Add each message in ``messages`` in order (duplicates are skipped)."""
        for message in messages:
            self.add(message)

    def get_by_role(self, role: str) -> list[Message]:
        """Return all messages whose ``role`` equals ``role``, in storage order."""
        return [message for message in self.storage if message.role == role]

    def get_by_content(self, content: str) -> list[Message]:
        """Return all messages whose ``content`` contains ``content`` as a substring."""
        return [message for message in self.storage if content in message.content]

    def delete(self, message: Message):
        """Delete the specified message from storage, while updating the index.

        Raises:
            ValueError: if ``message`` is not in storage (from ``list.remove``).
        """
        self.storage.remove(message)
        if not message.cause_by:
            return
        # .get() avoids creating an empty bucket for an unknown cause_by.
        bucket = self.index.get(message.cause_by)
        if bucket is not None and message in bucket:
            bucket.remove(message)
            # Drop emptied buckets so "key in index" keeps meaning "has messages".
            if not bucket:
                del self.index[message.cause_by]

    def clear(self):
        """Clear storage and index."""
        self.storage = []
        self.index = defaultdict(list)

    def count(self) -> int:
        """Return the number of messages in storage."""
        return len(self.storage)

    def try_remember(self, keyword: str) -> list[Message]:
        """Try to recall all messages whose content contains ``keyword``."""
        return self.get_by_content(keyword)

    def get(self, k=0) -> list[Message]:
        """Return the most recent k memories, or all of them when k=0.

        The result is a new list (a slice), so mutating it leaves storage intact.
        """
        return self.storage[-k:]

    def remember(self, observed: list[Message], k=10) -> list[Message]:
        """Return the messages in ``observed`` that are not among the last k memories.

        When k=0, ``observed`` is compared against the whole storage.
        The relative order of ``observed`` is preserved.
        """
        already_observed = self.get(k)
        return [message for message in observed if message not in already_observed]

    def get_by_action(self, action: Type[Action]) -> list[Message]:
        """Return all messages triggered by a specified Action.

        Returns a copy, so callers cannot mutate the internal index. Looking up
        an action with no messages returns ``[]`` without touching the index.
        """
        return list(self.index.get(action, []))

    def get_by_actions(self, actions: Iterable[Type[Action]]) -> list[Message]:
        """Return messages triggered by any of the specified Actions.

        Results are concatenated in the order of ``actions``. Actions with no
        messages contribute nothing.
        """
        rsp: list[Message] = []
        for action in actions:
            rsp.extend(self.index.get(action, []))
        return rsp

    def get_by_and_actions(self, actions: Iterable[Type[Action]]) -> list[Message]:
        """Return messages triggered by the specified Actions, only if every one has some.

        If any Action in ``actions`` has no messages, the result is ``[]``.
        Otherwise the messages are concatenated in the order of ``actions``.
        """
        rsp: list[Message] = []
        for action in actions:
            bucket = self.index.get(action)
            # Test emptiness, not key presence: an empty bucket means "no messages".
            if not bucket:
                return []
            rsp.extend(bucket)
        return rsp