vincentiusyoshuac commited on
Commit
e476572
·
verified ·
1 Parent(s): 655630b

Update memory.py

Browse files
Files changed (1) hide show
  1. memory.py +13 -6
memory.py CHANGED
@@ -17,11 +17,13 @@ class CognitiveMemory(nn.Module):
17
  self.consolidation_threshold = 0.7
18
 
19
  # Memory projection layers
20
- self.key_proj = nn.Linear(context_size, 64)
21
- self.value_proj = nn.Linear(context_size, 64)
22
 
23
  def add_memory(self, context: torch.Tensor, activation: float):
24
  """Store new memory with adaptive importance"""
 
 
25
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
26
  self.memory_queue.append({
27
  'context': context.detach(),
@@ -46,9 +48,14 @@ class CognitiveMemory(nn.Module):
46
  if not self.memory_queue:
47
  return torch.zeros_like(query)
48
 
49
- keys = torch.stack([self.key_proj(m['context']) for m in self.memory_queue])
50
- values = torch.stack([self.value_proj(m['context']) for m in self.memory_queue])
 
 
 
 
51
  query_proj = self.key_proj(query)
52
 
53
- scores = F.softmax(keys @ query_proj.t(), dim=0)
54
- return (scores * values).sum(dim=0)
 
 
17
  self.consolidation_threshold = 0.7
18
 
19
  # Memory projection layers
20
+ self.key_proj = nn.Linear(1, 64) # Changed from context_size to 1
21
+ self.value_proj = nn.Linear(1, 64) # Changed from context_size to 1
22
 
23
  def add_memory(self, context: torch.Tensor, activation: float):
24
  """Store new memory with adaptive importance"""
25
+ # Ensure context is 1D tensor with single value
26
+ context = context.reshape(-1)
27
  importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
28
  self.memory_queue.append({
29
  'context': context.detach(),
 
48
  if not self.memory_queue:
49
  return torch.zeros_like(query)
50
 
51
+ # Ensure query is 1D tensor with single value
52
+ query = query.reshape(1, 1)
53
+ memories = torch.stack([m['context'].reshape(1, 1) for m in self.memory_queue])
54
+
55
+ keys = self.key_proj(memories)
56
+ values = self.value_proj(memories)
57
  query_proj = self.key_proj(query)
58
 
59
+ scores = F.softmax(torch.matmul(keys, query_proj.transpose(0, 1)), dim=0)
60
+ retrieved = torch.matmul(scores.transpose(0, 1), values)
61
+ return retrieved.squeeze(0)