vincentiusyoshuac commited on
Commit
254bf80
·
verified ·
1 Parent(s): ca65cde

Update network.py

Browse files
Files changed (1) hide show
  1. network.py +16 -15
network.py CHANGED
@@ -7,13 +7,13 @@ from typing import Dict, Optional
7
  from .node import CognitiveNode
8
 
9
  class DynamicCognitiveNet(nn.Module):
10
- """Jaringan dengan manajemen koneksi yang robust"""
11
  def __init__(self, input_size: int, output_size: int):
12
  super().__init__()
13
  self.input_size = input_size
14
  self.output_size = output_size
15
 
16
- # Node input/output
17
  self.input_nodes = nn.ModuleList([
18
  CognitiveNode(i, 1) for i in range(input_size)
19
  ])
@@ -25,7 +25,7 @@ class DynamicCognitiveNet(nn.Module):
25
  self.connections = nn.ParameterDict()
26
  self._init_base_connections()
27
 
28
- # Konteks pembelajaran
29
  self.emotional_state = nn.Parameter(torch.tensor(0.0))
30
  self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
31
  self.loss_fn = nn.MSELoss()
@@ -40,6 +40,9 @@ class DynamicCognitiveNet(nn.Module):
40
  )
41
 
42
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 
43
  # Pemrosesan input
44
  activations = {}
45
  for i, node in enumerate(self.input_nodes):
@@ -61,8 +64,8 @@ class DynamicCognitiveNet(nn.Module):
61
  return torch.stack(outputs).squeeze()
62
 
63
  def structural_update(self, global_reward: float):
64
- """Update struktur dengan validasi koneksi"""
65
- # Adaptasi kekuatan koneksi
66
  for conn_id in list(self.connections.keys()):
67
  new_weight = self.connections[conn_id] + 0.1 * global_reward
68
  self.connections[conn_id].data = new_weight.clamp(-1, 1)
@@ -74,7 +77,7 @@ class DynamicCognitiveNet(nn.Module):
74
  self.connections[new_conn] = nn.Parameter(torch.randn(1) * 0.1)
75
 
76
  def _find_underutilized_connection(self) -> Optional[str]:
77
- """Mencari koneksi input-output yang underutilized"""
78
  input_act = {n.id: np.mean(n.recent_activations)
79
  for n in self.input_nodes if n.recent_activations}
80
  output_act = {n.id: np.mean(n.recent_activations)
@@ -92,10 +95,10 @@ class DynamicCognitiveNet(nn.Module):
92
  self.optimizer.zero_grad()
93
 
94
  try:
95
- pred = self(x)
96
- loss = self.loss_fn(pred, y)
97
- except RuntimeError as e:
98
- print(f"Error selama forward pass: {e}")
99
  return float('nan')
100
 
101
  # Regularisasi struktural
@@ -105,16 +108,14 @@ class DynamicCognitiveNet(nn.Module):
105
  try:
106
  total_loss.backward()
107
  self.optimizer.step()
108
- except RuntimeError as e:
109
- print(f"Error selama backpropagation: {e}")
110
  return float('nan')
111
 
112
- # Update state emosional
113
  self.emotional_state.data = torch.sigmoid(
114
  self.emotional_state + (0.5 - loss.item()) * 0.1
115
  )
116
-
117
- # Update struktur
118
  self.structural_update(0.5 - loss.item())
119
 
120
  return total_loss.item()
 
7
  from .node import CognitiveNode
8
 
9
  class DynamicCognitiveNet(nn.Module):
10
+ """Arsitektur jaringan dengan manajemen tensor yang robust"""
11
  def __init__(self, input_size: int, output_size: int):
12
  super().__init__()
13
  self.input_size = input_size
14
  self.output_size = output_size
15
 
16
+ # Node dengan input size 1
17
  self.input_nodes = nn.ModuleList([
18
  CognitiveNode(i, 1) for i in range(input_size)
19
  ])
 
25
  self.connections = nn.ParameterDict()
26
  self._init_base_connections()
27
 
28
+ # Sistem pembelajaran
29
  self.emotional_state = nn.Parameter(torch.tensor(0.0))
30
  self.optimizer = optim.AdamW(self.parameters(), lr=0.001)
31
  self.loss_fn = nn.MSELoss()
 
40
  )
41
 
42
  def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ # Validasi dimensi input
44
+ x = x.view(-1)
45
+
46
  # Pemrosesan input
47
  activations = {}
48
  for i, node in enumerate(self.input_nodes):
 
64
  return torch.stack(outputs).squeeze()
65
 
66
  def structural_update(self, global_reward: float):
67
+ """Update struktur jaringan"""
68
+ # Update kekuatan koneksi
69
  for conn_id in list(self.connections.keys()):
70
  new_weight = self.connections[conn_id] + 0.1 * global_reward
71
  self.connections[conn_id].data = new_weight.clamp(-1, 1)
 
77
  self.connections[new_conn] = nn.Parameter(torch.randn(1) * 0.1)
78
 
79
  def _find_underutilized_connection(self) -> Optional[str]:
80
+ """Mencari pasangan node yang kurang aktif"""
81
  input_act = {n.id: np.mean(n.recent_activations)
82
  for n in self.input_nodes if n.recent_activations}
83
  output_act = {n.id: np.mean(n.recent_activations)
 
95
  self.optimizer.zero_grad()
96
 
97
  try:
98
+ pred = self(x.view(-1))
99
+ loss = self.loss_fn(pred, y.view(-1))
100
+ except Exception as e:
101
+ print(f"Error forward: {e}")
102
  return float('nan')
103
 
104
  # Regularisasi struktural
 
108
  try:
109
  total_loss.backward()
110
  self.optimizer.step()
111
+ except Exception as e:
112
+ print(f"Error backward: {e}")
113
  return float('nan')
114
 
115
+ # Update emosi
116
  self.emotional_state.data = torch.sigmoid(
117
  self.emotional_state + (0.5 - loss.item()) * 0.1
118
  )
 
 
119
  self.structural_update(0.5 - loss.item())
120
 
121
  return total_loss.item()