Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# @Desc : the implement of memory storage | |
# https://github.com/geekan/MetaGPT/blob/main/metagpt/memory/memory_storage.py | |
from typing import List | |
from pathlib import Path | |
from langchain.vectorstores.faiss import FAISS | |
from autoagents.system.const import DATA_PATH, MEM_TTL | |
from autoagents.system.logs import logger | |
from autoagents.system.schema import Message | |
from autoagents.system.utils.serialize import serialize_message, deserialize_message | |
from autoagents.system.document_store.faiss_store import FaissStore | |
class MemoryStorage(FaissStore):
    """Long-term memory storage for a single role, using Faiss as ANN search engine.

    Each message is stored as one document: its text is ``message.content`` and its
    metadata holds the serialized message under the ``"message_ser"`` key, which is
    turned back into a ``Message`` with ``deserialize_message``.

    Call :meth:`recover_memory` first: it sets ``role_id`` and ``role_mem_path``,
    from which the on-disk index/storage file names are derived.
    """

    def __init__(self, mem_ttl: int = MEM_TTL):
        # NOTE: FaissStore.__init__ is not called here; ``self.store`` is filled in
        # later by recover_memory() (load from disk) or add() (first write).
        self.role_id: str = None  # set by recover_memory()
        self.role_mem_path: Path = None  # per-role directory, set by recover_memory()
        self.mem_ttl: int = mem_ttl  # later use
        # Score threshold used by search(): hits scoring below it are dropped.
        # Experience value. TODO: tune the threshold used to filter similar memories.
        self.threshold: float = 0.1
        self._initialized: bool = False
        self.store: FAISS = None  # Faiss engine; None until loaded or first written

    def is_initialized(self) -> bool:
        """Return True once a Faiss store has been loaded or created."""
        return self._initialized

    def recover_memory(self, role_id: str) -> List[Message]:
        """Bind this storage to ``role_id`` and load any persisted messages.

        Creates ``DATA_PATH/role_mem/<role_id>/`` if needed, then loads the Faiss
        store via ``self._load()``.

        Args:
            role_id: identifier of the role whose memory is recovered.

        Returns:
            The messages found in the loaded store, or ``[]`` when no store
            could be loaded.
        """
        self.role_id = role_id
        # Wrap DATA_PATH in Path() so this works whether it is a str or a Path.
        self.role_mem_path = Path(DATA_PATH) / 'role_mem' / str(self.role_id)
        self.role_mem_path.mkdir(parents=True, exist_ok=True)

        self.store = self._load()
        if not self.store:
            # TODO init `self.store` here with the raw faiss api instead of under `add`
            return []

        # Reads the docstore's private ``_dict`` mapping (id -> document); this may
        # break if the underlying docstore implementation changes.
        messages = [
            deserialize_message(document.metadata.get("message_ser"))
            for document in self.store.docstore._dict.values()
        ]
        self._initialized = True
        return messages

    def _get_index_and_store_fname(self):
        """Return ``(index_path, storage_path)`` for this role's persisted store.

        Both files live in ``role_mem_path`` and are named ``<role_id>.index`` and
        ``<role_id>.pkl``. Logs an error and returns ``(None, None)`` when
        :meth:`recover_memory` has not been called yet.
        """
        if not self.role_mem_path:
            logger.error(f'You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory')
            return None, None
        mem_dir = Path(self.role_mem_path)
        index_fpath = mem_dir / f'{self.role_id}.index'
        storage_fpath = mem_dir / f'{self.role_id}.pkl'
        return index_fpath, storage_fpath

    def persist(self):
        """Persist the store to disk via ``FaissStore.persist`` and log it."""
        super().persist()
        logger.debug(f'Agent {self.role_id} persist memory into local')

    def add(self, message: Message) -> None:
        """Add ``message`` to the memory storage and persist the store to disk.

        The first call creates the Faiss store via ``self._write``; later calls
        append to it with ``add_texts``.
        """
        docs = [message.content]
        metadatas = [{"message_ser": serialize_message(message)}]
        if not self.store:
            # init Faiss
            self.store = self._write(docs, metadatas)
            self._initialized = True
        else:
            self.store.add_texts(texts=docs, metadatas=metadatas)
        self.persist()
        logger.info(f"Agent {self.role_id}'s memory_storage add a message")

    def search(self, message: Message, k=4) -> List[Message]:
        """Search for stored messages that are dissimilar to ``message``.

        Runs a similarity search on ``message.content`` for the top ``k`` hits and
        drops every hit whose score is below ``self.threshold`` (a smaller score
        means a more similar document), i.e. near-duplicates are filtered out.

        Args:
            message: the query message.
            k: number of nearest neighbours fetched before filtering.

        Returns:
            The deserialized surviving messages in search order; ``[]`` when no
            store exists yet.
        """
        if not self.store:
            return []

        resp = self.store.similarity_search_with_score(query=message.content, k=k)
        # `not score < threshold` (rather than `score >= threshold`) keeps the exact
        # original filtering, including for non-comparable scores such as NaN.
        return [
            deserialize_message(item.metadata.get("message_ser"))
            for item, score in resp
            if not score < self.threshold
        ]

    def clean(self):
        """Delete this role's persisted index/storage files and reset in-memory state."""
        index_fpath, storage_fpath = self._get_index_and_store_fname()
        for fpath in (index_fpath, storage_fpath):
            if fpath:
                # missing_ok=True already tolerates absent files, so no exists() check
                # (which would also race with concurrent deletion).
                fpath.unlink(missing_ok=True)

        self.store = None
        self._initialized = False