Spaces:
Running
Running
File size: 4,303 Bytes
3fdcc70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from typing import List, Optional, Dict
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
BaseMessage,
)
from .utils import count_tokens, get_max_context_length
class MessageMemory:
    """Conversation history buffer holding langchain message objects.

    Keeps an ordered list of messages (``stored_messages``). It converts
    between these messages and ``{"role": ..., "content": ...}`` dicts,
    renders the history as text, and can drop the oldest messages so the
    rendered history fits a token budget (``cut_memory``).
    """

    def __init__(
        self,
        max_tokens: int = -1,
        margin: int = 1500,
        messages: Optional[List["BaseMessage"]] = None,
    ) -> None:
        """Create a memory buffer.

        Args:
            max_tokens: Upper bound on the history's token budget. Any value
                <= 0 means "no explicit limit" and is replaced by a very
                large sentinel. ``cut_memory`` also caps the budget with
                ``get_max_context_length``.
            margin: Number of tokens that must stay free below the budget
                after trimming (see ``cut_memory``).
            messages: Optional initial history. The list is copied, so later
                changes to this memory never mutate the caller's list.
        """
        # 8e8 acts as an "unlimited" sentinel for the min() in cut_memory.
        self.max_tokens = max_tokens if max_tokens > 0 else 8e8
        self.margin = margin
        self.init_messages(messages)

    def reset(self) -> List["BaseMessage"]:
        """Clear the history and return the (now empty) message list."""
        self.init_messages()
        return self.stored_messages

    def init_messages(self, messages=None) -> None:
        """Replace the stored history with a copy of *messages*, or empty it."""
        # Copy so that appending/inserting into our history never silently
        # mutates a list owned by the caller.
        self.stored_messages = list(messages) if messages is not None else []

    @classmethod
    def to_messages(cls, items: List[Dict]):
        """Convert role/content dicts into langchain message objects.

        Roles map as: ``"system"`` -> SystemMessage, ``"user"`` ->
        HumanMessage, ``"assistant"`` -> AIMessage.

        Raises:
            TypeError: if an item is not a dict, or its ``"role"`` is missing
                or not one of the three roles above.
            KeyError: if a valid item has no ``"content"`` key.
        """
        messages = []
        for m in items:
            role = m.get("role") if isinstance(m, dict) else None
            if role not in ("system", "user", "assistant"):
                raise TypeError(
                    "expected a dict with role 'system', 'user' or "
                    f"'assistant', got {m!r}"
                )
            content = m["content"]
            if role == "system":
                messages.append(SystemMessage(content=content))
            elif role == "user":
                messages.append(HumanMessage(content=content))
            else:
                messages.append(AIMessage(content=content))
        return messages

    def to_dict(self):
        """Convert the stored history into role/content dicts.

        System, Human and AI messages map to roles ``"system"``, ``"user"``
        and ``"assistant"``. Other ``BaseMessage`` subclasses are skipped.

        Raises:
            TypeError: if a stored item is not a ``BaseMessage`` or its
                ``type`` is None.
        """
        messages = []
        for m in self.stored_messages:
            if not isinstance(m, BaseMessage) or m.type is None:
                raise TypeError(f"not a valid BaseMessage: {m!r}")
            if isinstance(m, SystemMessage):
                messages.append({"role": "system", "content": m.content})
            elif isinstance(m, HumanMessage):
                messages.append({"role": "user", "content": m.content})
            elif isinstance(m, AIMessage):
                messages.append({"role": "assistant", "content": m.content})
        return messages

    def get_memory(self):
        """Return the stored message list itself (not a copy)."""
        return self.stored_messages

    def update_message(self, message: "BaseMessage") -> List["BaseMessage"]:
        """Append *message* to the history and return the stored list."""
        self.stored_messages.append(message)
        return self.stored_messages

    def insert_messages(
        self, idx: int = 0, messages: Optional[List["BaseMessage"]] = None
    ) -> List["BaseMessage"]:
        """Insert *messages*, keeping their order, before position *idx*.

        ``idx`` follows ``list.insert`` conventions: negative values count
        from the end, and out-of-range values clamp to either end. ``None``
        or an empty list inserts nothing.

        Returns:
            The stored message list.
        """
        if messages:
            # Slice assignment preserves the block's order for every idx.
            # Repeated list.insert of the reversed block only did so for
            # 0 <= idx <= len and reversed it otherwise.
            self.stored_messages[idx:idx] = list(messages)
        return self.stored_messages

    @classmethod
    def messages2str(cls, history):
        """Render *history* as ``<role>: content`` lines, one per message.

        Only System, Human and AI messages are rendered; others are skipped.
        """
        parts = []
        for m in history:
            if isinstance(m, SystemMessage):
                prefix = "<system>: "
            elif isinstance(m, HumanMessage):
                prefix = "<user>: "
            elif isinstance(m, AIMessage):
                prefix = "<assistant>: "
            else:
                continue
            parts.append(prefix + m.content + "\n")
        return "".join(parts)

    def memory2str(self):
        """Render the stored history with ``messages2str``."""
        return self.messages2str(self.stored_messages)

    def cut_memory(self, LLM_encoding: str):
        """Drop the oldest messages until the rendered history fits the budget.

        The budget is ``min(self.max_tokens,
        get_max_context_length(LLM_encoding))``. The history fits when
        ``budget - count_tokens(LLM_encoding, rendered_history)`` is strictly
        greater than ``self.margin``. The fewest leading messages needed to
        satisfy this are removed. If even an empty history does not fit, the
        memory is reset.

        Returns:
            The (possibly trimmed) stored message list.
        """
        # The budget does not depend on which messages are kept: compute once.
        max_tokens = min(self.max_tokens, get_max_context_length(LLM_encoding))
        # start == len(...) tries the empty history as a last resort.
        for start in range(len(self.stored_messages) + 1):
            history = self.stored_messages[start:]
            num = count_tokens(LLM_encoding, self.messages2str(history))
            if max_tokens - num > self.margin:
                self.stored_messages = history
                return self.stored_messages
        self.init_messages()
        return self.stored_messages
if __name__ == "__main__":
    # Manual smoke test: build a long synthetic conversation and trim it to
    # the gpt-3.5-turbo token budget.
    import os

    # Only provide a default; don't clobber a cache dir the user already set.
    os.environ.setdefault(
        "TIKTOKEN_CACHE_DIR", "/mnt/petrelfs/liuzhaoyang/workspace/tmp"
    )
    messages = [
        SystemMessage(content="SystemMessage 1"),
        HumanMessage(content="Remember a = 5 * 4."),
        AIMessage(content="SystemMessage 2"),
        HumanMessage(content="what is the value of a?"),
    ] * 400
    print(SystemMessage(content="SystemMessage 1").content)
    print(len(messages))
    # Pass by keyword: positionally the list would land in ``margin``,
    # leaving the memory empty and making cut_memory compare int > list.
    mem = MessageMemory(max_tokens=-1, messages=messages)
    messages = mem.cut_memory("gpt-3.5-turbo")
    print(len(messages))
|