File size: 2,184 Bytes
b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 c94bd82 b668479 c94bd82 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 ee91c59 b668479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# cognitive_net/memory.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from typing import Deque, Dict, Any
class CognitiveMemory(nn.Module):
    """Bounded memory store with importance-based pruning and attention retrieval.

    Stored contexts live in a ``deque`` with ``maxlen=capacity`` (appending past
    capacity evicts the oldest entry). ``retrieve`` projects all stored contexts
    into a ``memory_dim``-dimensional key/value space and returns a softmax-
    weighted sum of the values, attended by the projected query.
    """

    def __init__(
        self,
        context_size: int,
        capacity: int = 100,
        memory_dim: int = 64,
    ):
        """
        Args:
            context_size: Feature size of stored/queried context vectors
                (input size of the key and value projections).
            capacity: Maximum number of memories kept; appending beyond it
                drops the oldest one.
            memory_dim: Size of the key/value projection space and of the
                vector returned by ``retrieve``. Defaults to 64, the value
                that was previously hard-coded.
        """
        super().__init__()
        self.context_size = context_size
        self.capacity = capacity
        self.memory_dim = memory_dim
        self.memory_queue: Deque[Dict[str, Any]] = deque(maxlen=capacity)
        # Projections into the key/value space used by ``retrieve``.
        self.key_proj = nn.Linear(context_size, memory_dim)
        self.value_proj = nn.Linear(context_size, memory_dim)
        # Per-consolidation multiplicative decay of importance. Registered as a
        # Parameter (saved in state_dict, moved by .to()), but consolidation
        # reads its current value without recording autograd history.
        self.importance_decay = nn.Parameter(torch.tensor(0.95))
        # Consolidation parameters
        # NOTE(review): not read by any method of this class; pruning uses
        # ``importance_threshold`` below. Kept so existing attribute access works.
        self.consolidation_threshold = 0.7
        # Amount added to each memory's age on every consolidation.
        self.age_decay = 0.1
        # Memories whose decayed importance is not strictly above this are pruned.
        self.importance_threshold = 0.2

    def add_memory(self, context: torch.Tensor, activation: float) -> None:
        """Store a detached copy of ``context`` with an initial importance score.

        Args:
            context: Tensor to remember. A detached clone is stored, so the
                caller's tensor and its autograd history are not retained.
                ``retrieve`` stacks all stored contexts, so they must share one
                shape (presumably ``(context_size,)`` — TODO confirm with callers).
            activation: Scalar activation; the stored importance is
                ``sigmoid(0.5 * activation + 0.2)``, a value in (0, 1).
        """
        importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
        self.memory_queue.append({
            'context': context.detach().clone(),
            'importance': importance,
            'age': torch.tensor(0.0),
        })

    def consolidate_memories(self) -> None:
        """Decay and age every memory, then prune the ones that became weak.

        Each memory's importance is multiplied by ``importance_decay`` and its
        age increased by ``age_decay``. Only memories whose importance remains
        strictly above ``importance_threshold`` are kept, in their original order.
        """
        # Use the decay as a plain float. Multiplying by the Parameter itself
        # attached autograd history to the stored importance. That history grew
        # with every consolidation and was never freed, and it turned the stored
        # tensors into non-leaves, which breaks copy.deepcopy of the module. It
        # could also mix devices once the module was moved with .to().
        decay = self.importance_decay.item()
        survivors: Deque[Dict[str, Any]] = deque(maxlen=self.capacity)
        for mem in self.memory_queue:
            mem['importance'] *= decay
            mem['age'] += self.age_decay
            if mem['importance'] > self.importance_threshold:
                survivors.append(mem)
        self.memory_queue = survivors

    def retrieve(self, query: torch.Tensor) -> torch.Tensor:
        """Attend over stored memories with ``query`` and return the read-out.

        Args:
            query: Query vector, presumably shape ``(context_size,)``; it is
                unsqueezed to a batch of one before projection (TODO confirm).

        Returns:
            Zeros of length ``memory_dim`` on ``query``'s device when memory is
            empty. Otherwise, for a 1-D query, a ``memory_dim``-length tensor:
            the sum of the value projections of all stored contexts, each
            weighted by a softmax (across memories) of the dot product between
            its key projection and the projected query. Stored importance scores
            do not influence retrieval.
        """
        if not self.memory_queue:
            return torch.zeros(self.memory_dim, device=query.device)
        contexts = torch.stack([m['context'] for m in self.memory_queue])
        keys = self.key_proj(contexts)
        values = self.value_proj(contexts)
        query_proj = self.key_proj(query.unsqueeze(0))
        # (N, D) @ (D, 1) -> (N, 1); softmax over dim 0 normalizes across memories.
        scores = F.softmax(keys @ query_proj.T, dim=0)
        return (scores * values).sum(dim=0)