import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

class CognitiveMemory(nn.Module):
    """Differentiable memory system with consolidation and retrieval"""
    def __init__(self, context_size: int, capacity: int = 100):
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        self.memory_queue = deque(maxlen=capacity)
        
        # Memory importance parameters
        self.importance_decay = nn.Parameter(torch.tensor(0.95))
        self.consolidation_threshold = 0.7  # currently unused; pruning below uses a fixed 0.2

        # Project stored contexts into a shared 64-dim key/value space
        self.key_proj = nn.Linear(context_size, 64)
        self.value_proj = nn.Linear(context_size, 64)

    def add_memory(self, context: torch.Tensor, activation: float):
        """Store new memory with adaptive importance"""
        # Flatten the context to a 1D tensor of length context_size
        context = context.reshape(-1)
        # Map the activation level to a (0, 1) importance score
        importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
        self.memory_queue.append({
            'context': context.detach(),
            'importance': importance,
            'age': 0.0
        })
        
    def consolidate_memories(self):
        """Memory consolidation through importance reweighting"""
        # Decay importance and age memories outside of autograd: consolidation
        # is a maintenance step, not part of the differentiable forward pass,
        # so we avoid accumulating a graph on the learnable decay parameter
        with torch.no_grad():
            for mem in self.memory_queue:
                mem['importance'] = mem['importance'] * self.importance_decay
                mem['age'] += 0.1

        # Prune memories whose importance has decayed too far
        self.memory_queue = deque(
            [m for m in self.memory_queue if m['importance'] > 0.2],
            maxlen=self.capacity
        )
        
    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Attention-based memory retrieval"""
        if not self.memory_queue:
            # No memories yet: return a zero vector in the value-projection
            # space so the return shape matches the non-empty case
            return query.new_zeros(self.value_proj.out_features)

        # Shape the query as (1, context_size) and stack the stored
        # contexts into an (N, context_size) batch
        query = query.reshape(1, -1)
        memories = torch.stack([m['context'] for m in self.memory_queue])

        keys = self.key_proj(memories)      # (N, 64)
        values = self.value_proj(memories)  # (N, 64)
        query_proj = self.key_proj(query)   # (1, 64)

        # Scaled dot-product attention over the stored memories
        scores = F.softmax(
            torch.matmul(keys, query_proj.transpose(0, 1)) / keys.size(-1) ** 0.5,
            dim=0
        )  # (N, 1)
        retrieved = torch.matmul(scores.transpose(0, 1), values)  # (1, 64)
        return retrieved.squeeze(0)  # (64,)
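
# A minimal usage sketch (illustrative only; the context size of 16 and the
# 0.9 activation below are assumptions, not part of the module above): store
# a few contexts, consolidate once, then retrieve against a fresh query.
if __name__ == "__main__":
    memory = CognitiveMemory(context_size=16)
    for _ in range(5):
        memory.add_memory(torch.randn(16), activation=0.9)
    memory.consolidate_memories()
    retrieved = memory.retrieve(torch.randn(16))
    print(retrieved.shape)  # torch.Size([64]) -- the value-projection dimension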