Leonydis137 committed
Commit 23150a0 · verified · 1 Parent(s): fc41652

Update memory.py

Files changed (1)
  1. memory.py +16 -17
memory.py CHANGED
@@ -1,22 +1,21 @@
-import os
 from transformers import AutoTokenizer, AutoModel
+import torch
+import os
 
 class MemoryVectorStore:
-    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
-        self.cache_dir = "/tmp/hf_cache"
-        os.makedirs(self.cache_dir, exist_ok=False)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name, cache_dir=self.cache_dir
-        )
-        self.model = AutoModel.from_pretrained(
-            model_name, cache_dir=self.cache_dir
-        )
-
-        self.memory = []
+    def __init__(self):
+        # Download model to project dir
+        model_name = "sentence-transformers/all-MiniLM-L6-v2"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./hf-cache")
+        self.model = AutoModel.from_pretrained(model_name, cache_dir="./hf-cache")
 
-    def add(self, text):
-        self.memory.append(text)
+    def encode(self, texts):
+        # Ensure list of texts
+        if isinstance(texts, str):
+            texts = [texts]
 
-    def retrieve(self, query):
-        return self.memory
+        with torch.no_grad():
+            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
+            outputs = self.model(**inputs)
+            embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
+            return embeddings.numpy()
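
After this change, MemoryVectorStore no longer keeps an in-memory list with add/retrieve; it only exposes encode, which mean-pools the model's last hidden states into one NumPy vector per input text. A minimal usage sketch, assuming memory.py is importable as shown; the cosine_similarity helper is illustrative and not part of this commit:

import numpy as np

from memory import MemoryVectorStore  # assumption: memory.py is on the import path

store = MemoryVectorStore()

# Each row is one 384-dimensional vector (all-MiniLM-L6-v2 hidden size).
docs = store.encode(["note about vector stores", "grocery list for the weekend"])
query = store.encode("how do vector stores work?")


def cosine_similarity(a, b):
    # Normalize rows, then take dot products: result has shape (len(a), len(b)).
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a @ b.T


scores = cosine_similarity(query, docs)   # shape (1, 2)
print(docs.shape, scores.argmax())        # (2, 384) and the index of the closest document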