MartialTerran committed on
Commit 25652b9 · verified · 1 Parent(s): 93d3a73

Create GreaterThan_MLP_V1.1_with_FailuresAnalysis.ipynb

GreaterThan_MLP_V1.1_with_FailuresAnalysis.ipynb ADDED
# GreaterThan_MLP_V1.1_with_FailuresAnalysis.py
"""
The objective of GreaterThan_MLP_V1.0.py is to establish a fundamental performance baseline
for a numerical comparison task using a deliberately simple Multi-Layer Perceptron (MLP).
It avoids all natural language processing techniques by treating the problem as pure binary
classification on a fixed-size vector. The dataset consists of synthetically generated pairs
of decimal numbers in the range 00.00-99.99, each formatted as two integer digits and two
decimal digits (e.g., 10.00 and 09.21). Each pair is deconstructed and flattened into an
8-dimensional feature vector of its raw digits ([1, 0, 0, 0, 0, 9, 2, 1]). The model is
trained to predict a single binary label (0 for left > right, 1 for right > left), directly
testing the MLP's capability to learn the hierarchical rules of numerical magnitude from
the positional values of the input digits alone.

The script is designed for maximum clarity and serves as a fundamental baseline for this
reasoning problem: a simple MLP for binary classification, using on-the-fly data generation.
The generate_mlp_data function produces the 8-dimensional input vectors and binary labels,
and the output cleanly indicates which number is greater.

The MLP baseline performs remarkably well, achieving over 99.9% accuracy in deciding
"GreaterThan". This indicates that the underlying logic of numerical comparison can be
learned from raw digits by a simple neural network, provided the input is structured as a
fixed-size vector. However, even at high accuracy, failures still occur. Understanding why,
and on what data, the model fails is the next critical step in ML engineering: this is how
we discover dataset biases, edge cases, and architectural weaknesses. This modified script,
GreaterThan_MLP_V1.1_with_FailuresAnalysis.py, incorporates logic to automatically detect
and log failures to a CSV file when accuracy is high, creating a valuable dataset artifact
for future analysis and the development of more robust models.

Here is the output of the first demonstration run in Colab:
Model initialized with 9473 parameters.

--- Starting Training ---
Epoch [1/100], Train Loss: 0.4015, Train Acc: 82.67%, | Val Loss: 0.1690, Val Acc: 97.03%
Epoch [2/100], Train Loss: 0.1743, Train Acc: 92.94%, | Val Loss: 0.0974, Val Acc: 98.04%
Epoch [3/100], Train Loss: 0.1300, Train Acc: 94.54%, | Val Loss: 0.0741, Val Acc: 98.61%
Epoch [4/100], Train Loss: 0.1112, Train Acc: 95.20%, | Val Loss: 0.0618, Val Acc: 98.96%
Epoch [5/100], Train Loss: 0.1019, Train Acc: 95.61%, | Val Loss: 0.0565, Val Acc: 98.79%
Epoch [6/100], Train Loss: 0.0926, Train Acc: 96.04%, | Val Loss: 0.0498, Val Acc: 99.10%
 -> High accuracy detected. Scanning for failures...
 -> Logged 607 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 180 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [7/100], Train Loss: 0.0857, Train Acc: 96.33%, | Val Loss: 0.0456, Val Acc: 99.19%
 -> High accuracy detected. Scanning for failures...
 -> Logged 562 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 161 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [8/100], Train Loss: 0.0827, Train Acc: 96.47%, | Val Loss: 0.0430, Val Acc: 99.14%
 -> High accuracy detected. Scanning for failures...
 -> Logged 538 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 171 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [9/100], Train Loss: 0.0767, Train Acc: 96.73%, | Val Loss: 0.0398, Val Acc: 99.33%
 -> High accuracy detected. Scanning for failures...
 -> Logged 462 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 133 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [10/100], Train Loss: 0.0727, Train Acc: 96.87%, | Val Loss: 0.0376, Val Acc: 99.33%
 -> High accuracy detected. Scanning for failures...
 -> Logged 457 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 134 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [11/100], Train Loss: 0.0692, Train Acc: 97.04%, | Val Loss: 0.0380, Val Acc: 99.06%
 -> High accuracy detected. Scanning for failures...
 -> Logged 703 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 189 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [12/100], Train Loss: 0.0665, Train Acc: 97.17%, | Val Loss: 0.0333, Val Acc: 99.42%
 -> High accuracy detected. Scanning for failures...
 -> Logged 365 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 117 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [13/100], Train Loss: 0.0619, Train Acc: 97.36%, | Val Loss: 0.0316, Val Acc: 99.42%
 -> High accuracy detected. Scanning for failures...
 -> Logged 396 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 115 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [14/100], Train Loss: 0.0599, Train Acc: 97.46%, | Val Loss: 0.0301, Val Acc: 99.41%
 -> High accuracy detected. Scanning for failures...
 -> Logged 397 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 119 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [15/100], Train Loss: 0.0568, Train Acc: 97.63%, | Val Loss: 0.0282, Val Acc: 99.47%
 -> High accuracy detected. Scanning for failures...
 -> Logged 359 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 107 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [16/100], Train Loss: 0.0550, Train Acc: 97.72%, | Val Loss: 0.0266, Val Acc: 99.53%
 -> High accuracy detected. Scanning for failures...
 -> Logged 331 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 94 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [17/100], Train Loss: 0.0524, Train Acc: 97.80%, | Val Loss: 0.0256, Val Acc: 99.55%
 -> High accuracy detected. Scanning for failures...
 -> Logged 321 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 91 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [18/100], Train Loss: 0.0504, Train Acc: 97.93%, | Val Loss: 0.0240, Val Acc: 99.56%
 -> High accuracy detected. Scanning for failures...
 -> Logged 290 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 87 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [19/100], Train Loss: 0.0472, Train Acc: 98.04%, | Val Loss: 0.0228, Val Acc: 99.53%
 -> High accuracy detected. Scanning for failures...
 -> Logged 288 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 93 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [20/100], Train Loss: 0.0447, Train Acc: 98.16%, | Val Loss: 0.0216, Val Acc: 99.61%
 -> High accuracy detected. Scanning for failures...
 -> Logged 289 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [21/100], Train Loss: 0.0445, Train Acc: 98.12%, | Val Loss: 0.0201, Val Acc: 99.69%
 -> High accuracy detected. Scanning for failures...
 -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 63 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [22/100], Train Loss: 0.0412, Train Acc: 98.29%, | Val Loss: 0.0191, Val Acc: 99.65%
 -> High accuracy detected. Scanning for failures...
 -> Logged 227 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [23/100], Train Loss: 0.0395, Train Acc: 98.35%, | Val Loss: 0.0181, Val Acc: 99.65%
 -> High accuracy detected. Scanning for failures...
 -> Logged 236 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [24/100], Train Loss: 0.0373, Train Acc: 98.48%, | Val Loss: 0.0170, Val Acc: 99.71%
 -> High accuracy detected. Scanning for failures...
 -> Logged 209 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 58 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [25/100], Train Loss: 0.0362, Train Acc: 98.53%, | Val Loss: 0.0164, Val Acc: 99.68%
 -> High accuracy detected. Scanning for failures...
 -> Logged 222 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [26/100], Train Loss: 0.0345, Train Acc: 98.61%, | Val Loss: 0.0153, Val Acc: 99.73%
 -> High accuracy detected. Scanning for failures...
 -> Logged 199 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 53 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [27/100], Train Loss: 0.0317, Train Acc: 98.74%, | Val Loss: 0.0149, Val Acc: 99.61%
 -> High accuracy detected. Scanning for failures...
 -> Logged 253 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [28/100], Train Loss: 0.0302, Train Acc: 98.80%, | Val Loss: 0.0134, Val Acc: 99.80%
 -> High accuracy detected. Scanning for failures...
 -> Logged 162 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 40 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [29/100], Train Loss: 0.0299, Train Acc: 98.80%, | Val Loss: 0.0127, Val Acc: 99.77%
 -> High accuracy detected. Scanning for failures...
 -> Logged 163 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 46 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [30/100], Train Loss: 0.0261, Train Acc: 98.98%, | Val Loss: 0.0125, Val Acc: 99.68%
 -> High accuracy detected. Scanning for failures...
 -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [31/100], Train Loss: 0.0251, Train Acc: 99.05%, | Val Loss: 0.0110, Val Acc: 99.84%
 -> High accuracy detected. Scanning for failures...
 -> Logged 135 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 32 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [32/100], Train Loss: 0.0246, Train Acc: 99.01%, | Val Loss: 0.0108, Val Acc: 99.78%
 -> High accuracy detected. Scanning for failures...
 -> Logged 167 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 43 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [33/100], Train Loss: 0.0237, Train Acc: 99.07%, | Val Loss: 0.0103, Val Acc: 99.83%
 -> High accuracy detected. Scanning for failures...
 -> Logged 121 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [34/100], Train Loss: 0.0224, Train Acc: 99.14%, | Val Loss: 0.0096, Val Acc: 99.86%
 -> High accuracy detected. Scanning for failures...
 -> Logged 127 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 29 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [35/100], Train Loss: 0.0220, Train Acc: 99.15%, | Val Loss: 0.0092, Val Acc: 99.89%
 -> High accuracy detected. Scanning for failures...
 -> Logged 100 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [36/100], Train Loss: 0.0204, Train Acc: 99.22%, | Val Loss: 0.0090, Val Acc: 99.83%
 -> High accuracy detected. Scanning for failures...
 -> Logged 126 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [37/100], Train Loss: 0.0194, Train Acc: 99.25%, | Val Loss: 0.0083, Val Acc: 99.89%
 -> High accuracy detected. Scanning for failures...
 -> Logged 93 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [38/100], Train Loss: 0.0191, Train Acc: 99.25%, | Val Loss: 0.0081, Val Acc: 99.85%
 -> High accuracy detected. Scanning for failures...
 -> Logged 110 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 30 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [39/100], Train Loss: 0.0182, Train Acc: 99.31%, | Val Loss: 0.0076, Val Acc: 99.89%
 -> High accuracy detected. Scanning for failures...
 -> Logged 74 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
 -> Logged 22 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv

"""
# REFACTORING MISSION:
# This script's objective is to perform automated failure analysis. When
# training or validation accuracy surpasses a 99% threshold, the script
# automatically logs the specific samples that the model failed on. These
# failures are appended to a CSV file for later inspection, which is
# invaluable for creating targeted test sets or improving the training data.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import random
import numpy as np
import zipfile
import os
import sys

# ==============================================================================
# Part 0: Configuration
# ==============================================================================
class Config:
    # --- Data ---
    num_samples = 100000  # Increased dataset size for more robust training
    train_split = 0.8

    # --- Model Architecture ---
    input_size = 8
    hidden_size_1 = 128
    hidden_size_2 = 64
    output_size = 1

    # --- Training ---
    learning_rate = 1e-4  # Slightly lower LR for finer tuning
    batch_size = 256
    epochs = 100  # Upper bound; the run above passes 99% val accuracy by epoch 6
    weight_decay = 1e-4

    # --- NEW: Failure Analysis ---
    # The accuracy threshold to trigger logging of failed samples.
    failure_log_threshold = 99.0
    # The name of the script, used for the output CSV file.
    script_name = "GreaterThan_MLP_V1.1_FailureAnalysis"
    failure_log_filename = f"{script_name}_failed_samples.csv"

    # --- Device ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")  # Runs once, at class-definition time

config = Config()

# For reproducibility
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)

# ==============================================================================
# Part 1: Colab Utility & Data Generation (Unchanged from V1.0)
# ==============================================================================
def is_in_colab():
    """Checks if the script is running in a Google Colab environment."""
    try:
        import google.colab
        return True
    except ImportError:
        return False

def generate_mlp_data(num_samples):
    """Generates synthetic data for the MLP."""
    print(f"Generating {num_samples} data points...")
    features, labels = [], []
    for _ in range(num_samples):
        a = round(random.uniform(0, 99.99), 2)
        b = round(random.uniform(0, 99.99), 2)
        while a == b:
            b = round(random.uniform(0, 99.99), 2)
        a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
        a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
        features.append(a_digits + b_digits)
        labels.append(0 if a > b else 1)
    X = torch.tensor(features, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
    print("Data generation complete.")
    return X, y

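# --- Encoding sanity check (an illustration, not part of the training pipeline).
# The docstring's example pair, 10.00 vs 09.21, maps to [1, 0, 0, 0, 0, 9, 2, 1]
# with label 0 (left > right), exactly as generate_mlp_data constructs it:
_a_str, _b_str = f"{10.00:05.2f}", f"{9.21:05.2f}"  # "10.00", "09.21" (zero-padded)
assert [int(d) for d in _a_str + _b_str if d.isdigit()] == [1, 0, 0, 0, 0, 9, 2, 1]
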
# ==============================================================================
# Part 2: Model Architecture (Unchanged from V1.0)
# ==============================================================================
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size_1), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_1, hidden_size_2), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_2, output_size)
        )
    def forward(self, x):
        return self.net(x)

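# --- Parameter-count check (an illustration). The run log above reports 9473
# parameters, which follows directly from the layer sizes in Config:
#   Linear(8->128): 8*128 + 128 = 1152;  Linear(128->64): 128*64 + 64 = 8256;
#   Linear(64->1): 64*1 + 1 = 65;  total 1152 + 8256 + 65 = 9473.
assert 8*128 + 128 + 128*64 + 64 + 64*1 + 1 == 9473
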
# ==============================================================================
# Part 3: MODIFIED Training and Evaluation Loop
# ==============================================================================

def log_failures(model, loader, split, epoch, filename, device):
    """
    NEW: Iterates through a data loader, finds incorrect predictions,
    and appends them to the specified CSV log file.
    """
    model.eval()
    failures_found = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.round(torch.sigmoid(outputs))
            # view(-1) rather than squeeze(), so a final batch of size 1
            # cannot collapse the boolean mask to a 0-dim tensor.
            mismatch_indices = (predicted != labels).view(-1)

            if mismatch_indices.any():
                failed_inputs = inputs[mismatch_indices]
                failed_true_labels = labels[mismatch_indices]
                failed_pred_labels = predicted[mismatch_indices]

                with open(filename, 'a') as f:
                    for i in range(failed_inputs.size(0)):
                        # Format the input vector back into a readable string
                        input_vec_int = failed_inputs[i].cpu().numpy().astype(int)
                        num1_str = f"{input_vec_int[0]}{input_vec_int[1]}.{input_vec_int[2]}{input_vec_int[3]}"
                        num2_str = f"{input_vec_int[4]}{input_vec_int[5]}.{input_vec_int[6]}{input_vec_int[7]}"

                        true_label = int(failed_true_labels[i].item())
                        pred_label = int(failed_pred_labels[i].item())

                        f.write(f"{epoch},{split},{num1_str},{num2_str},{true_label},{pred_label}\n")
                        failures_found += 1
    if failures_found > 0:
        print(f" -> Logged {failures_found} failures for '{split}' split to {filename}")

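# For reference, rows appended by log_failures follow the CSV header written in
# train_model below. A hypothetical pair of rows (illustrative values only, not
# taken from an actual run) would look like:
#   epoch,split,num1,num2,true_label(0:L>R;1:R>L),predicted_label
#   6,train,49.99,50.00,1,0
#   6,val,62.35,62.34,0,1
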
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, config):
    """The main training loop, now with failure logging."""
    print("\n--- Starting Training ---")

    # NEW: Initialize the failure log file with a header
    with open(config.failure_log_filename, 'w') as f:
        f.write("epoch,split,num1,num2,true_label(0:L>R;1:R>L),predicted_label\n")

    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(config.device), labels.to(config.device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            predicted = torch.round(torch.sigmoid(outputs))
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_avg_loss = train_loss / len(train_loader)
        train_accuracy = 100 * train_correct / train_total

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(config.device), labels.to(config.device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = torch.round(torch.sigmoid(outputs))
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_avg_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Train Loss: {train_avg_loss:.4f}, Train Acc: {train_accuracy:.2f}%, | "
              f"Val Loss: {val_avg_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

        # --- NEW: Conditional Failure Logging ---
        if train_accuracy > config.failure_log_threshold or val_accuracy > config.failure_log_threshold:
            print(f" -> High accuracy detected. Scanning for failures...")
            # Re-iterate over the loaders to find and log this epoch's specific failures
            log_failures(model, train_loader, 'train', epoch + 1, config.failure_log_filename, config.device)
            log_failures(model, val_loader, 'val', epoch + 1, config.failure_log_filename, config.device)

    print("--- Training Finished ---")

# ==============================================================================
# Part 4: MODIFIED Final Test Suite
# ==============================================================================
def run_final_tests(model, config):
    """Runs the trained model against a suite of hardcoded test cases and logs failures."""
    print("\n--- Running Final Test Suite ---")

    test_cases = [
        ("Simple Greater", 10.00, 9.21), ("Simple Lesser", 5.50, 50.50),
        ("Decimal Greater", 54.13, 54.12), ("Decimal Lesser", 99.98, 99.99),
        ("Edge Case: Large Difference", 0.01, 99.99), ("Edge Case: Zero", 0.00, 5.00),
        ("Tricky: Same Integer Part", 25.80, 25.79), ("Tricky: Crossover", 49.99, 50.00),
    ]

    results_log = "--- MLP Test Suite Results ---\n\n"
    correct_tests = 0

    model.eval()
    with torch.no_grad():
        for description, a, b in test_cases:
            a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
            a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
            feature_vector = torch.tensor(a_digits + b_digits, dtype=torch.float32).to(config.device)

            output = model(feature_vector)
            predicted_class = 1 if torch.sigmoid(output).item() > 0.5 else 0
            ground_truth_class = 0 if a > b else 1

            result = "CORRECT"
            if predicted_class != ground_truth_class:
                result = "INCORRECT"
                # --- NEW: Log failure to CSV. For these rows, "final_test" fills
                # the epoch column and the test description fills the split column.
                with open(config.failure_log_filename, 'a') as f:
                    f.write(f"final_test,{description.replace(',',';')},{a_str},{b_str},{ground_truth_class},{predicted_class}\n")
            else:
                correct_tests += 1

            predicted_winner = "Left" if predicted_class == 0 else "Right"
            log_line = (f"Test: '{description}' | {a_str} vs {b_str}\n"
                        f" -> Model says: {predicted_winner} is greater\n"
                        f" -> Result: {result}\n" + "-"*30 + "\n")
            print(log_line)
            results_log += log_line

    final_accuracy = 100 * correct_tests / len(test_cases)
    summary = f"\nFinal Test Accuracy: {final_accuracy:.2f}% ({correct_tests}/{len(test_cases)} correct)\n"
    print(summary)
    results_log += summary

    return results_log

# ==============================================================================
# Main Execution Block
# ==============================================================================
if __name__ == '__main__':
    X, y = generate_mlp_data(config.num_samples)

    dataset = TensorDataset(X, y)
    train_size = int(config.train_split * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    model = SimpleMLP(config.input_size, config.hidden_size_1, config.hidden_size_2, config.output_size).to(config.device)
    print(f"\nModel initialized with {sum(p.numel() for p in model.parameters())} parameters.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    criterion = nn.BCEWithLogitsLoss()

    train_model(model, train_loader, val_loader, optimizer, criterion, config.epochs, config)

    test_results = run_final_tests(model, config)

    if is_in_colab():
        print("\nDetected Google Colab environment. Zipping and downloading results...")
        results_filename = f"{config.script_name}_test_summary.txt"
        with open(results_filename, "w") as f:
            f.write(test_results)

        zip_filename = f"{config.script_name}_outputs.zip"
        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            zipf.write(results_filename)
            # Also include the new failure log in the zip file
            if os.path.exists(config.failure_log_filename):
                zipf.write(config.failure_log_filename)

        try:
            from google.colab import files
            files.download(zip_filename)
            print(f"Downloaded {zip_filename} successfully.")
        except Exception as e:
            print(f"Could not initiate download. Error: {e}")
    else:
        print(f"\nNot running in Colab. Test results printed above. Failures logged to {config.failure_log_filename}")