vincentiusyoshuac commited on
Commit
bd035eb
·
verified ·
1 Parent(s): 8d8fb56

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from cognitive_net import DynamicCognitiveNet
6
+
7
+ class ModelDemo:
8
+ def __init__(self):
9
+ self.net = DynamicCognitiveNet(input_size=5, output_size=1)
10
+ self.training_history = []
11
+ self.emotional_history = []
12
+
13
+ def train_sequence(self, sequence_str, epochs):
14
+ """Train model on input sequence and return visualizations"""
15
+ try:
16
+ # Parse input sequence
17
+ sequence = [float(x.strip()) for x in sequence_str.split(',')]
18
+ if len(sequence) < 6:
19
+ return "Error: Please input at least 6 numbers", None, None
20
+
21
+ # Prepare training data
22
+ X = torch.tensor(sequence[:-1]).float()
23
+ y = torch.tensor([sequence[-1]]).float()
24
+
25
+ # Training loop
26
+ losses = []
27
+ emotions = []
28
+ for epoch in range(epochs):
29
+ loss = self.net.train_step(X, y)
30
+ losses.append(loss)
31
+ emotions.append(self.net.emotional_state.item())
32
+
33
+ # Create loss plot
34
+ loss_fig = go.Figure()
35
+ loss_fig.add_trace(go.Scatter(y=losses, name='Loss'))
36
+ loss_fig.update_layout(title='Training Loss',
37
+ xaxis_title='Epoch',
38
+ yaxis_title='Loss')
39
+
40
+ # Create emotion plot
41
+ emotion_fig = go.Figure()
42
+ emotion_fig.add_trace(go.Scatter(y=emotions, name='Emotional State'))
43
+ emotion_fig.update_layout(title='Emotional State',
44
+ xaxis_title='Epoch',
45
+ yaxis_title='Value')
46
+
47
+ # Make prediction
48
+ with torch.no_grad():
49
+ pred = self.net(X)
50
+ result = f"Prediction: {pred.item():.4f} (Target: {y.item():.4f})"
51
+
52
+ return result, loss_fig, emotion_fig
53
+
54
+ except Exception as e:
55
+ return f"Error: {str(e)}", None, None
56
+
57
+ def visualize_memory(self):
58
+ """Visualize memory importance weights"""
59
+ memories = []
60
+ importances = []
61
+
62
+ for mem in self.net.nodes['input_0'].memory.memory_queue:
63
+ memories.append(mem['context'].numpy())
64
+ importances.append(mem['importance'].item())
65
+
66
+ if not memories:
67
+ return "No memories stored yet"
68
+
69
+ fig = go.Figure()
70
+ fig.add_trace(go.Bar(y=importances))
71
+ fig.update_layout(title='Memory Importance',
72
+ xaxis_title='Memory Index',
73
+ yaxis_title='Importance')
74
+ return fig
75
+
76
+ # Initialize demo
77
+ demo = ModelDemo()
78
+
79
+ # Create Gradio interface
80
+ with gr.Blocks(title="Cognitive Network Demo") as iface:
81
+ gr.Markdown("""
82
+ # Cognitive Network Interactive Demo
83
+
84
+ This demo shows a neural network with:
85
+ - Dynamic memory
86
+ - Emotional modulation
87
+ - Adaptive structure
88
+
89
+ Enter a sequence of numbers (comma-separated) to train the model to predict the next number.
90
+ """)
91
+
92
+ with gr.Row():
93
+ with gr.Column():
94
+ input_seq = gr.Textbox(label="Input Sequence (comma-separated)",
95
+ value="1, 2, 3, 4, 5, 6")
96
+ epochs = gr.Slider(minimum=10, maximum=500, value=100,
97
+ step=10, label="Training Epochs")
98
+ train_btn = gr.Button("Train Model")
99
+
100
+ result_text = gr.Textbox(label="Prediction Result")
101
+
102
+ with gr.Row():
103
+ loss_plot = gr.Plot(label="Training Loss")
104
+ emotion_plot = gr.Plot(label="Emotional State")
105
+
106
+ with gr.Row():
107
+ memory_btn = gr.Button("Visualize Memory")
108
+ memory_plot = gr.Plot(label="Memory Importance")
109
+
110
+ # Connect components
111
+ train_btn.click(
112
+ fn=demo.train_sequence,
113
+ inputs=[input_seq, epochs],
114
+ outputs=[result_text, loss_plot, emotion_plot]
115
+ )
116
+
117
+ memory_btn.click(
118
+ fn=demo.visualize_memory,
119
+ inputs=None,
120
+ outputs=memory_plot
121
+ )
122
+
123
+ # Launch app
124
+ iface.launch()