Commit
·
5071898
1
Parent(s):
5238878
Upload memory_func.py
Browse filesdeal with memory capacity
models_for_langchain/memory_func.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.llms.base import LLM
|
2 |
+
from langchain.memory import ConversationBufferWindowMemory
|
3 |
+
from transformers import GPT2TokenizerFast
|
4 |
+
from langchain.schema.messages import get_buffer_string
|
5 |
+
|
6 |
+
# Lazily-initialized module-level tokenizer cache. The original code called
# GPT2TokenizerFast.from_pretrained("gpt2") on EVERY invocation, which reloads
# the vocab/merges files each time -- a very expensive per-call cost.
_GPT2_TOKENIZER = None


def get_num_tokens(text):
    """Return the number of GPT-2 tokens in *text*.

    Args:
        text: The string to tokenize.

    Returns:
        int: Count of tokens produced by the GPT-2 fast tokenizer.
    """
    global _GPT2_TOKENIZER
    if _GPT2_TOKENIZER is None:
        # Load once and reuse; from_pretrained is far too slow to run per call.
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained("gpt2")
    return len(_GPT2_TOKENIZER.tokenize(text))
|
9 |
+
|
10 |
+
def get_memory_num_tokens(memory):
    """Return the total GPT-2 token count of all messages held in *memory*.

    Args:
        memory: A langchain memory object exposing ``chat_memory.messages``
            (e.g. ConversationBufferWindowMemory).

    Returns:
        int: Sum of per-message token counts; 0 for an empty buffer.
    """
    buffer = memory.chat_memory.messages
    # Generator expression: no need to materialize an intermediate list
    # just to feed sum(). Each message is rendered to its buffer-string
    # form individually so counts stay per-message.
    return sum(get_num_tokens(get_buffer_string([m])) for m in buffer)
|
13 |
+
|
14 |
+
def validate_memory_len(memory, max_token_limit=2000):
    """Trim *memory* in place so its total token count fits the limit.

    Drops the oldest messages (front of the buffer) one at a time until the
    buffer's GPT-2 token total is <= *max_token_limit*.

    Args:
        memory: A langchain memory object exposing ``chat_memory.messages``.
        max_token_limit: Maximum allowed total token count (default 2000).

    Returns:
        The same *memory* object, mutated in place.
    """
    buffer = memory.chat_memory.messages
    curr_buffer_length = get_memory_num_tokens(memory)
    # The original wrapped this while in an identical `if` (redundant) and
    # re-tokenized the ENTIRE remaining buffer after every pop -- O(n^2)
    # tokenization. Since the total is a per-message sum, subtracting the
    # dropped message's own count is exact and costs one tokenize per drop.
    # The `and buffer` guard makes an empty buffer safe regardless of limit.
    while curr_buffer_length > max_token_limit and buffer:
        dropped = buffer.pop(0)
        curr_buffer_length -= get_num_tokens(get_buffer_string([dropped]))
    return memory
|
22 |
+
|
23 |
+
if __name__ == '__main__':
    # Smoke test: print the GPT-2 token count of a trivial string.
    # Use the module's own helper instead of duplicating tokenizer
    # construction here -- output is identical to tokenizing directly.
    text = '''Hi'''
    print(get_num_tokens(text))
|