File size: 2,091 Bytes
ee91c59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Dict, List, Optional, Tuple

class CognitiveMemory(nn.Module):
    """Differentiable memory system with consolidation and retrieval.

    Memories are stored as detached context tensors in a bounded queue
    (``deque(maxlen=capacity)``, so the oldest entry is evicted when full).
    Each entry carries an importance score that decays on every call to
    :meth:`consolidate_memories`; entries whose importance drops to
    ``prune_threshold`` or below are discarded. :meth:`retrieve` performs
    softmax attention over the stored memories using learned key/value
    projections.

    Args:
        context_size: Size of the last dimension of stored contexts and
            queries.
        capacity: Maximum number of memories kept in the queue.
        memory_dim: Output size of the key/value projections, i.e. the last
            dimension of :meth:`retrieve`'s result (default 64).
        prune_threshold: Memories with importance ``<=`` this value are removed
            during consolidation (default 0.2).
    """

    def __init__(
        self,
        context_size: int,
        capacity: int = 100,
        memory_dim: int = 64,
        prune_threshold: float = 0.2,
    ):
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        self.memory_dim = memory_dim
        self.memory_queue = deque(maxlen=capacity)

        # Memory importance parameters
        self.importance_decay = nn.Parameter(torch.tensor(0.95))
        # NOTE(review): not read by any method of this class; kept as-is in
        # case external code inspects it.
        self.consolidation_threshold = 0.7
        self.prune_threshold = prune_threshold

        # Memory projection layers
        self.key_proj = nn.Linear(context_size, memory_dim)
        self.value_proj = nn.Linear(context_size, memory_dim)

    def add_memory(self, context: torch.Tensor, activation: float) -> None:
        """Store a new memory with importance derived from ``activation``.

        Args:
            context: Context tensor whose last dimension is ``context_size``.
                It is detached before storage, so no gradient flows back into
                whatever produced it.
            activation: Scalar activation (a Python number or a one-element
                tensor). Initial importance is
                ``sigmoid(0.5 * activation + 0.2)``.
        """
        # float() accepts Python numbers and one-element tensors alike and
        # avoids copy-constructing a tensor from a tensor.
        importance = torch.sigmoid(torch.tensor(float(activation) * 0.5 + 0.2))
        self.memory_queue.append({
            'context': context.detach(),
            'importance': importance,
            'age': 0.0,
        })

    def consolidate_memories(self) -> None:
        """Decay and age every memory, then prune the unimportant ones.

        Each pass multiplies every importance by ``importance_decay`` and adds
        0.1 to ``age``; memories whose importance is ``<= prune_threshold``
        are then dropped.
        """
        # Importance is only used for pruning, so use the decay's numeric
        # value: multiplying by the Parameter itself would chain an
        # ever-growing autograd graph onto each stored importance tensor.
        decay = self.importance_decay.item()
        for mem in self.memory_queue:
            mem['importance'] = mem['importance'] * decay
            mem['age'] += 0.1

        # Remove unimportant memories
        self.memory_queue = deque(
            (m for m in self.memory_queue
             if m['importance'] > self.prune_threshold),
            maxlen=self.capacity,
        )

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Attention-based memory retrieval.

        Args:
            query: Tensor of shape ``(..., context_size)``. Leading dimensions
                are treated as independent queries and broadcast against the
                leading dimensions of the stored contexts.

        Returns:
            Softmax-weighted sum of projected memory values, with last
            dimension equal to the value projection's output size. With no
            stored memories, a zero tensor of shape
            ``query.shape[:-1] + (out_features,)`` is returned so both paths
            agree on shape.
        """
        out_dim = self.value_proj.out_features
        if not self.memory_queue:
            return query.new_zeros((*query.shape[:-1], out_dim))

        # Stack once and project once (Linear acts on the last dim, so this
        # equals projecting each context separately).
        contexts = torch.stack([m['context'] for m in self.memory_queue])
        keys = self.key_proj(contexts)      # (N, *ctx_lead, D)
        values = self.value_proj(contexts)  # (N, *ctx_lead, D)
        query_proj = self.key_proj(query)   # (*q_lead, D)

        # Insert singleton dims right after the memory axis so the query's
        # leading (batch) dims line up with the contexts' instead of with N.
        while keys.dim() < query_proj.dim() + 1:
            keys = keys.unsqueeze(1)
            values = values.unsqueeze(1)

        scores = (keys * query_proj).sum(dim=-1)          # (N, *lead)
        weights = F.softmax(scores, dim=0).unsqueeze(-1)  # (N, *lead, 1)
        return (weights * values).sum(dim=0)              # (*lead, D)