# cognitive_net/network.py
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np
from typing import Dict, Optional
from .node import CognitiveNode
class DynamicCognitiveNet(nn.Module):
"""Jaringan dengan manajemen koneksi yang robust"""
def __init__(self, input_size: int, output_size: int):
super().__init__()
self.input_size = input_size
self.output_size = output_size
        # Input/output nodes
self.input_nodes = nn.ModuleList([
CognitiveNode(i, 1) for i in range(input_size)
])
self.output_nodes = nn.ModuleList([
CognitiveNode(input_size + i, 1) for i in range(output_size)
])
        # Connection management
self.connections = nn.ParameterDict()
self._init_base_connections()
        # Learning context
self.emotional_state = nn.Parameter(torch.tensor(0.0))
self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
self.loss_fn = nn.MSELoss()
def _init_base_connections(self):
"""Inisialisasi koneksi input-output"""
for in_node in self.input_nodes:
for out_node in self.output_nodes:
conn_id = f"{in_node.id}->{out_node.id}"
self.connections[conn_id] = nn.Parameter(
torch.randn(1) * 0.1
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Process each input feature through its own input node
activations = {}
for i, node in enumerate(self.input_nodes):
activations[node.id] = node(x[i].unsqueeze(0))
        # Integrate weighted input activations for each output node
outputs = []
for out_node in self.output_nodes:
integrated = []
for in_node in self.input_nodes:
conn_id = f"{in_node.id}->{out_node.id}"
weight = torch.sigmoid(self.connections[conn_id])
integrated.append(activations[in_node.id] * weight)
if integrated:
combined = sum(integrated) / math.sqrt(len(integrated))
outputs.append(out_node(combined))
        # (output_size, 1) -> (output_size,), without collapsing a single output to a scalar
        return torch.stack(outputs).squeeze(-1)
def structural_update(self, global_reward: float):
"""Update struktur dengan validasi koneksi"""
# Adaptasi kekuatan koneksi
for conn_id in list(self.connections.keys()):
new_weight = self.connections[conn_id] + 0.1 * global_reward
self.connections[conn_id].data = new_weight.clamp(-1, 1)
        # Create a new connection when performance is poor
if global_reward < -0.5:
new_conn = self._find_underutilized_connection()
if new_conn and new_conn not in self.connections:
self.connections[new_conn] = nn.Parameter(torch.randn(1) * 0.1)
def _find_underutilized_connection(self) -> Optional[str]:
"""Mencari koneksi input-output yang underutilized"""
input_act = {n.id: np.mean(n.recent_activations)
for n in self.input_nodes if n.recent_activations}
output_act = {n.id: np.mean(n.recent_activations)
for n in self.output_nodes if n.recent_activations}
if not input_act or not output_act:
return None
src = min(input_act, key=lambda k: input_act[k])
tgt = min(output_act, key=lambda k: output_act[k])
return f"{src}->{tgt}"
def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
"""Training step dengan error handling"""
self.optimizer.zero_grad()
try:
pred = self(x)
loss = self.loss_fn(pred, y)
except RuntimeError as e:
print(f"Error selama forward pass: {e}")
return float('nan')
        # Structural regularization: L1 penalty on connection weights
reg_loss = sum(p.abs().mean() for p in self.connections.values())
total_loss = loss + 0.01 * reg_loss
try:
total_loss.backward()
self.optimizer.step()
except RuntimeError as e:
print(f"Error selama backpropagation: {e}")
return float('nan')
        # Update the emotional state (drifts up when loss is below 0.5)
        self.emotional_state.data = torch.sigmoid(
            self.emotional_state.data + (0.5 - loss.item()) * 0.1
        )
        # Update the structure based on the reward (0.5 - loss)
self.structural_update(0.5 - loss.item())
return total_loss.item()
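
# Usage sketch: a minimal, illustrative training loop. It assumes that
# CognitiveNode in node.py accepts (node_id, input_dim), exposes an integer
# `id`, and records the `recent_activations` history relied on above; those
# details live outside this file. Run it as a module (e.g.
# `python -m cognitive_net.network`) because of the relative import at the top.
if __name__ == "__main__":
    torch.manual_seed(0)
    net = DynamicCognitiveNet(input_size=4, output_size=2)
    x = torch.randn(4)  # one flat sample with `input_size` features
    y = torch.randn(2)  # matching target with `output_size` values
    for step in range(50):
        loss = net.train_step(x, y)
        if step % 10 == 0:
            print(f"step {step:3d}  loss={loss:.4f}")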