#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc   : the implementation of memory storage
# https://github.com/geekan/MetaGPT/blob/main/metagpt/memory/memory_storage.py

from typing import List, Optional
from pathlib import Path

from langchain.vectorstores.faiss import FAISS

from autoagents.system.const import DATA_PATH, MEM_TTL
from autoagents.system.logs import logger
from autoagents.system.schema import Message
from autoagents.system.utils.serialize import serialize_message, deserialize_message
from autoagents.system.document_store.faiss_store import FaissStore


class MemoryStorage(FaissStore):
    """
    The memory storage with Faiss as ANN search engine
    """

    def __init__(self, mem_ttl: int = MEM_TTL):
        self.role_id: Optional[str] = None
        self.role_mem_path: Optional[Path] = None
        self.mem_ttl: int = mem_ttl  # reserved for later use
        self.threshold: float = 0.1  # empirical value. TODO: tune the threshold used to filter out similar memories
        self._initialized: bool = False

        self.store: Optional[FAISS] = None  # Faiss engine

    @property
    def is_initialized(self) -> bool:
        return self._initialized

    def recover_memory(self, role_id: str) -> List[Message]:
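        """Load the persisted Faiss index for `role_id` (if any) and return the deserialized messages."""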
        self.role_id = role_id
        self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
        self.role_mem_path.mkdir(parents=True, exist_ok=True)

        self.store = self._load()
        messages = []
        if not self.store:
            # TODO: initialize `self.store` here with the raw Faiss API instead of inside `add`
            pass
        else:
            for _id, document in self.store.docstore._dict.items():
                messages.append(deserialize_message(document.metadata.get("message_ser")))
            self._initialized = True

        return messages

    def _get_index_and_store_fname(self):
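        """Resolve the on-disk paths of the role's Faiss index file and pickled docstore."""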
        if not self.role_mem_path:
            logger.error(f'You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory')
            return None, None
        index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
        storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
        return index_fpath, storage_fpath

    def persist(self):
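        """Write the in-memory store to the role's local files via FaissStore.persist."""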
        super().persist()
        logger.debug(f'Agent {self.role_id} persist memory into local')

    def add(self, message: Message) -> None:
        """add a message into memory storage"""
        docs = [message.content]
        metadatas = [{"message_ser": serialize_message(message)}]
        if not self.store:
            # init Faiss
            self.store = self._write(docs, metadatas)
            self._initialized = True
        else:
            self.store.add_texts(texts=docs, metadatas=metadatas)
        self.persist()
        logger.info(f"Agent {self.role_id}'s memory_storage add a message")

    def search(self, message: Message, k=4) -> List[Message]:
        """search the k nearest messages and keep only those dissimilar to the query"""
        if not self.store:
            return []

        resp = self.store.similarity_search_with_score(
            query=message.content,
            k=k
        )
        # filter out results whose score is below the threshold, i.e. those too similar to the query
        filtered_resp = []
        for item, score in resp:
            # a smaller score means a more similar result
            if score < self.threshold:
                continue
            # convert the search result back into a Message
            metadata = item.metadata
            new_mem = deserialize_message(metadata.get("message_ser"))
            filtered_resp.append(new_mem)
        return filtered_resp

    def clean(self):
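        """Delete the persisted index/docstore files and reset the in-memory state."""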
        index_fpath, storage_fpath = self._get_index_and_store_fname()
        if index_fpath and index_fpath.exists():
            index_fpath.unlink(missing_ok=True)
        if storage_fpath and storage_fpath.exists():
            storage_fpath.unlink(missing_ok=True)

        self.store = None
        self._initialized = False
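

# Minimal usage sketch (not part of the original module). It assumes `Message`
# accepts `content` and `role` keyword arguments, as in the accompanying schema
# module, and that an embedding backend for FaissStore is configured.
if __name__ == "__main__":
    storage = MemoryStorage()

    # Load any previously persisted memories for this role.
    recovered = storage.recover_memory(role_id="demo_role")
    logger.info(f"recovered {len(recovered)} messages")

    # Add a new message; the first add also initializes the Faiss store.
    storage.add(Message(content="Design a login page", role="user"))

    # Retrieve up to 4 nearest messages, keeping only the sufficiently
    # dissimilar ones (score >= threshold).
    hits = storage.search(Message(content="Design a signup page", role="user"), k=4)
    logger.info(f"dissimilar hits: {[m.content for m in hits]}")

    # Remove the persisted files and reset the in-memory state.
    storage.clean()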