Leonydis137 committed
Commit bc19ff9 · verified · 1 Parent(s): 58a4c8c

Update utils.py

Files changed (1):
  1. utils.py  +35 -23
utils.py CHANGED
@@ -2,6 +2,7 @@ import os
 import json
 import faiss
 import numpy as np
+from uuid import uuid4
 from datetime import datetime
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -13,7 +14,8 @@ MEMORY_INDEX_PATH = "memory.index"
 MEMORY_TEXTS_PATH = "memory_texts.json"
 CHAT_LOG_PATH = "chatlog.jsonl"
 FEEDBACK_PATH = "feedback.jsonl"
-SUMMARY_TRIGGER = 100
+SUMMARY_TRIGGER = int(os.getenv("SUMMARY_TRIGGER", 100))
+CHUNK_SIZE = int(os.getenv("SUMMARY_CHUNK", 10))
 
 # === Load models ===
 embedder = SentenceTransformer(EMBED_MODEL)
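Note: the two thresholds are now read from the environment, so deployments can tune summarization without touching the code. A minimal sketch of overriding them (values are illustrative, not from the commit); since utils.py calls os.getenv(...) at module level, the variables must be set before the import:

    # Illustrative only: tune the summarization thresholds via the environment.
    import os

    os.environ["SUMMARY_TRIGGER"] = "50"  # summarize once more than 50 items are stored
    os.environ["SUMMARY_CHUNK"] = "5"     # fold the 5 oldest items into one summary

    import utils  # utils.py reads os.getenv(...) at import time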
@@ -23,25 +25,31 @@ summary_model = AutoModelForCausalLM.from_pretrained(SUMMARIZER_MODEL).eval()
 embedding_dim = embedder.get_sentence_embedding_dimension()
 
 # === Memory state ===
-if os.path.exists(MEMORY_INDEX_PATH) and os.path.exists(MEMORY_TEXTS_PATH):
-    memory_index = faiss.read_index(MEMORY_INDEX_PATH)
-    with open(MEMORY_TEXTS_PATH, "r") as f:
-        memory_texts = json.load(f)
-else:
+try:
+    if os.path.exists(MEMORY_INDEX_PATH) and os.path.exists(MEMORY_TEXTS_PATH):
+        memory_index = faiss.read_index(MEMORY_INDEX_PATH)
+        with open(MEMORY_TEXTS_PATH, "r") as f:
+            memory_texts = json.load(f)
+    else:
+        raise FileNotFoundError
+except:
     memory_index = faiss.IndexFlatL2(embedding_dim)
     memory_texts = []
 
-def embed(text):
-    """Embed text into vector"""
-    return embedder.encode([text])
+def embed(texts):
+    """Embed a list of texts into vectors"""
+    return embedder.encode(texts)
 
 def add_to_memory(text):
     """Add a memory item"""
-    vec = embed(text)
+    vec = embed([text])
     memory_index.add(np.array(vec))
-    memory_texts.append(text)
+    memory_texts.append({
+        "id": str(uuid4()),
+        "text": text,
+        "timestamp": datetime.now().isoformat()
+    })
     save_memory()
-
     if len(memory_texts) > SUMMARY_TRIGGER:
         summarize_old_memories()
 
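Note: memory entries change from bare strings to dicts, which keeps FAISS row i aligned with memory_texts[i] while adding provenance. One caveat worth flagging in review: a memory_texts.json saved by the previous version still holds bare strings, and the bare except above only guards the load itself, not that schema mismatch. Shape of one entry after this change (field values are illustrative):

    # Illustrative entry in memory_texts after this commit:
    entry = {
        "id": "3f9a2c1e-0b1d-4c5e-9f6a-7b8c9d0e1f2a",  # str(uuid4())
        "text": "User prefers concise answers",        # the stored memory
        "timestamp": "2026-01-01T12:00:00",            # datetime.now().isoformat()
    }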
@@ -49,36 +57,40 @@ def retrieve_memories(query, k=3):
     """Retrieve top relevant memories"""
     if memory_index.ntotal == 0:
         return []
-    vec = embed(query)
+    vec = embed([query])
     D, I = memory_index.search(np.array(vec), k)
-    return [memory_texts[i] for i in I[0] if i < len(memory_texts)]
+    return [memory_texts[i]["text"] for i in I[0] if i < len(memory_texts)]
 
 def save_memory():
     """Save FAISS and text memory to disk"""
     faiss.write_index(memory_index, MEMORY_INDEX_PATH)
     with open(MEMORY_TEXTS_PATH, "w") as f:
         json.dump(memory_texts, f)
+    print(f"[INFO] Memory saved: {len(memory_texts)} items")
 
 def reset_memory():
     """Reset memory entirely"""
     memory_index.reset()
     memory_texts.clear()
-    if os.path.exists(MEMORY_INDEX_PATH):
-        os.remove(MEMORY_INDEX_PATH)
-    if os.path.exists(MEMORY_TEXTS_PATH):
-        os.remove(MEMORY_TEXTS_PATH)
+    for path in [MEMORY_INDEX_PATH, MEMORY_TEXTS_PATH]:
+        if os.path.exists(path):
+            os.remove(path)
 
 def summarize_old_memories():
     """Replace older entries with a summary to save space"""
-    old = "\n".join(memory_texts[:10])
+    old = "\n".join(m["text"] for m in memory_texts[:CHUNK_SIZE])
     inputs = summary_tokenizer(f"Summarize: {old}", return_tensors="pt")
     output = summary_model.generate(**inputs, max_new_tokens=100)
     summary = summary_tokenizer.decode(output[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
 
-    memory_texts[:10] = [summary]
+    memory_texts[:CHUNK_SIZE] = [{
+        "id": str(uuid4()),
+        "text": summary,
+        "timestamp": datetime.now().isoformat()
+    }]
     memory_index.reset()
-    for text in memory_texts:
-        vec = embed(text)
+    for mem in memory_texts:
+        vec = embed([mem["text"]])
         memory_index.add(np.array(vec))
     save_memory()
 
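Note: the rebuild loop in summarize_old_memories() still re-embeds one entry per encode() call. Since embed() now accepts a list, the whole rebuild could be batched; a hypothetical sketch, not part of this commit:

    # Hypothetical batched rebuild (not in this commit): one encode() call
    # for all surviving memories instead of one call per entry.
    def rebuild_index():
        memory_index.reset()
        if memory_texts:
            vecs = embed([m["text"] for m in memory_texts])  # shape (n, embedding_dim)
            memory_index.add(np.array(vecs, dtype="float32"))
        save_memory()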
@@ -109,4 +121,4 @@ def generate_suggestions(feedback_text):
     outputs = summary_model.generate(**inputs, max_new_tokens=100)
     suggestions = summary_tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
     log_feedback(feedback_text, suggestions)
-    return suggestions
+    return suggestions
 
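Taken together, the updated module can be exercised end to end; a minimal usage sketch, assuming EMBED_MODEL and SUMMARIZER_MODEL (defined earlier in utils.py, outside this diff) resolve to downloadable models:

    # Minimal usage sketch for the updated utils.py.
    import utils

    utils.add_to_memory("The user is allergic to peanuts.")
    utils.add_to_memory("The user's favorite editor is Vim.")

    for text in utils.retrieve_memories("what should I avoid cooking?", k=2):
        print("-", text)  # retrieve_memories now unwraps the "text" field

    utils.reset_memory()  # clears the index, the list, and both files on disk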