File size: 2,184 Bytes
b668479
ee91c59
 
 
 
b668479
ee91c59
 
b668479
ee91c59
 
 
 
b668479
ee91c59
b668479
c94bd82
 
b668479
 
 
 
 
c94bd82
ee91c59
b668479
ee91c59
 
b668479
ee91c59
b668479
ee91c59
b668479
ee91c59
b668479
 
ee91c59
 
b668479
 
 
 
 
ee91c59
b668479
ee91c59
b668479
ee91c59
b668479
 
 
 
ee91c59
b668479
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# cognitive_net/memory.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Deque, Dict, Any

class CognitiveMemory(nn.Module):
    """Differentiable memory system with biological consolidation mechanisms.

    Stores detached context vectors in a bounded FIFO queue, each tagged with
    an importance score and an age. Retrieval is content-based soft attention
    over the stored contexts; consolidation decays importance/ages entries and
    prunes those whose importance falls below a fixed threshold.
    """

    def __init__(self, context_size: int, capacity: int = 100,
                 memory_dim: int = 64):
        """
        Args:
            context_size: Dimensionality of the context vectors stored/queried.
            capacity: Maximum number of memories retained (FIFO eviction).
            memory_dim: Dimensionality of the key/value projection space
                (default 64, matching the original hard-coded size).
        """
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        self.memory_dim = memory_dim
        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)

        # Memory projection layers with adaptive scaling
        self.key_proj = nn.Linear(context_size, memory_dim)
        self.value_proj = nn.Linear(context_size, memory_dim)
        # NOTE(review): declared as a Parameter but nothing here backpropagates
        # through it (consolidation detaches it); confirm whether it is meant
        # to be learned elsewhere.
        self.importance_decay = nn.Parameter(torch.tensor(0.95))

        # Consolidation parameters
        self.consolidation_threshold = 0.7  # currently unused in this file
        self.age_decay = 0.1

    def add_memory(self, context: torch.Tensor, activation: float):
        """Store a context vector with dynamic importance weighting.

        Importance is sigmoid(activation * 0.5 + 0.2), so it always starts
        above ~0.55 and only pruning via consolidation can remove it.
        """
        importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
        self.memory_queue.append({
            # detach+clone: stored memories must not keep the caller's
            # autograd graph alive
            'context': context.detach().clone(),
            'importance': importance,
            'age': torch.tensor(0.0)
        })

    def consolidate_memories(self):
        """Decay importance, age entries, and prune low-importance memories.

        Entries whose decayed importance drops to 0.2 or below are discarded.
        """
        # BUGFIX: detach the decay Parameter before multiplying. The original
        # in-place `*= self.importance_decay` attached an ever-growing autograd
        # graph to tensors held indefinitely in the queue (memory leak across
        # repeated consolidations). Values are numerically identical.
        decay = self.importance_decay.detach()
        retained: Deque[Dict[str, Any]] = deque(maxlen=self.capacity)
        for mem in self.memory_queue:
            mem['importance'] = mem['importance'] * decay
            mem['age'] = mem['age'] + self.age_decay
            if mem['importance'] > 0.2:
                retained.append(mem)
        self.memory_queue = retained

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Content-based memory retrieval via softmax attention.

        Args:
            query: 1-D context vector of shape (context_size,).

        Returns:
            Tensor of shape (memory_dim,): the attention-weighted sum of
            projected memory values, or zeros when no memories are stored.
        """
        if not self.memory_queue:
            # Keep the fallback's size in lockstep with the projections.
            return torch.zeros(self.memory_dim, device=query.device)

        contexts = torch.stack([m['context'] for m in self.memory_queue])
        keys = self.key_proj(contexts)
        values = self.value_proj(contexts)
        query_proj = self.key_proj(query.unsqueeze(0))

        # Softmax over the memory axis (dim=0): one weight per stored memory.
        scores = F.softmax(keys @ query_proj.T, dim=0)
        return (scores * values).sum(dim=0)