Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# Modified from https://github.com/geekan/MetaGPT/blob/main/metagpt/memory/memory.py | |
from collections import defaultdict | |
from typing import Iterable, Type | |
from autoagents.actions import Action | |
from autoagents.system.schema import Message | |
class Memory:
    """The most basic memory: super-memory.

    Messages are kept in insertion order in ``storage``. ``index`` groups the
    same messages by their ``cause_by`` action so lookups by action avoid a
    full scan. Invariant kept by this class: an action is a key of ``index``
    only while at least one stored message was caused by it.
    """

    def __init__(self):
        """Initialize an empty storage list and an empty index dictionary."""
        self.storage: "list[Message]" = []
        self.index: "dict[Type[Action], list[Message]]" = defaultdict(list)

    def add(self, message: "Message"):
        """Add a new message to storage, while updating the index.

        A message equal to one already stored is ignored.
        """
        if message in self.storage:
            return
        self.storage.append(message)
        if message.cause_by:
            self.index[message.cause_by].append(message)

    def add_batch(self, messages: "Iterable[Message]"):
        """Add each message in ``messages`` via :meth:`add` (duplicates skipped)."""
        for message in messages:
            self.add(message)

    def get_by_role(self, role: str) -> "list[Message]":
        """Return all messages of a specified role, in insertion order."""
        return [message for message in self.storage if message.role == role]

    def get_by_content(self, content: str) -> "list[Message]":
        """Return all messages whose content contains ``content`` as a substring."""
        return [message for message in self.storage if content in message.content]

    def delete(self, message: "Message"):
        """Delete the specified message from storage, while updating the index.

        Raises:
            ValueError: if ``message`` is not in storage.
        """
        self.storage.remove(message)
        # Use .get() so probing the defaultdict never creates an entry.
        bucket = self.index.get(message.cause_by) if message.cause_by else None
        if bucket and message in bucket:
            bucket.remove(message)
            if not bucket:
                # Drop emptied buckets so "key present" keeps meaning
                # "has messages"; get_by_and_actions depends on that.
                del self.index[message.cause_by]

    def clear(self):
        """Clear storage and index."""
        self.storage = []
        self.index = defaultdict(list)

    def count(self) -> int:
        """Return the number of messages in storage."""
        return len(self.storage)

    def try_remember(self, keyword: str) -> "list[Message]":
        """Try to recall all messages containing a specified keyword.

        Same result as :meth:`get_by_content`.
        """
        return self.get_by_content(keyword)

    def get(self, k=0) -> "list[Message]":
        """Return the most recent ``k`` memories (oldest first); all when ``k <= 0``.

        Always returns a new list, never the internal storage list.
        """
        if k <= 0:
            # storage[-k:] with a negative k would drop the oldest -k
            # messages instead of returning everything.
            return list(self.storage)
        return self.storage[-k:]

    def remember(self, observed: "list[Message]", k=10) -> "list[Message]":
        """Return the messages in ``observed`` not among the most recent ``k`` memories.

        Compares against the whole memory when ``k <= 0``. The order of
        ``observed`` is preserved and nothing is added to memory.
        """
        already_observed = self.get(k)
        return [message for message in observed if message not in already_observed]

    def get_by_action(self, action: "Type[Action]") -> "list[Message]":
        """Return all messages triggered by a specified Action.

        Returns a new list. An unknown action yields ``[]`` without being
        inserted into the index.
        """
        return list(self.index.get(action, ()))

    def get_by_actions(self, actions: "Iterable[Type[Action]]") -> "list[Message]":
        """Return messages triggered by any of the specified Actions.

        Messages are concatenated in the order of ``actions``; actions with no
        stored messages are skipped.
        """
        rsp = []
        for action in actions:
            rsp.extend(self.index.get(action, ()))
        return rsp

    def get_by_and_actions(self, actions: "Iterable[Type[Action]]") -> "list[Message]":
        """Return messages triggered by the specified Actions only if *all* have fired.

        If any action in ``actions`` has no stored message, return ``[]``;
        otherwise concatenate their messages in the order of ``actions``.
        """
        rsp = []
        for action in actions:
            # Check emptiness, not just key presence, so stale/empty buckets
            # can never make this return a partial result.
            messages = self.index.get(action)
            if not messages:
                return []
            rsp.extend(messages)
        return rsp