# GreaterThan_MLP_V1.1_with_FailuresAnalysis.py
"""
The objective of GreaterThan_MLP_V1.0.py is to establish a fundamental performance baseline
for a numerical comparison task using a deliberately simple Multi-Layer Perceptron (MLP).
It avoids all natural language processing techniques by treating the problem as pure binary
classification on a fixed-size vector. The dataset consists of synthetically generated pairs
of decimal numbers in [00.00, 99.99], each formatted with two integer and two fractional
digits (e.g., 10.00 and 09.21), which are deconstructed and flattened into an 8-dimensional
feature vector of their raw digits ([1, 0, 0, 0, 0, 9, 2, 1]).
The model is then trained to predict a single binary label (0 for left > right, 1 for
right > left), directly testing the MLP's capability to learn the hierarchical rules of
numerical magnitude from the positional values of the input digits alone.
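
A minimal sketch of this digit encoding (mirroring the generate_mlp_data function below):

    a, b = 10.00, 9.21
    a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"   # "10.00", "09.21"
    features = [int(d) for d in a_str + b_str if d.isdigit()]
    # -> [1, 0, 0, 0, 0, 9, 2, 1]
    label = 0 if a > b else 1                   # -> 0 (left is greater)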

The plan is deliberately simple: an MLP trained for binary classification, designed for
maximum clarity so it can serve as a fundamental baseline for this reasoning problem.
The 8-dimensional input vector, built from the four digits of each number, is the model's
entire interface, and the output cleanly indicates which number is greater. Data is
generated on the fly; the generate_mlp_data function produces the 8-dimensional input
vectors and binary labels. A small hand-picked test suite probes the trained model on
simple, edge, and tricky cases:

    test_cases = [
        ("Simple Greater", 10.00, 9.21), ("Simple Lesser", 5.50, 50.50),
        ("Decimal Greater", 54.13, 54.12), ("Decimal Lesser", 99.98, 99.99),
        ("Edge Case: Large Difference", 0.01, 99.99), ("Edge Case: Zero", 0.00, 5.00),
        ("Tricky: Same Integer Part", 25.80, 25.79), ("Tricky: Crossover", 49.99, 50.00),
    ]

The MLP baseline model performs remarkably well, reaching close to 99.9% validation accuracy
on the "GreaterThan" decision. This indicates that the underlying logic of numerical comparison
can be learned from raw digits by a simple neural network, provided the input is structured
as a fixed-size vector.
However, even at high accuracy, failures still occur. Understanding why, and on which data,
the model fails is the next critical step in ML engineering: it is how we discover dataset
biases, edge cases, and architectural weaknesses.
Here is the modified script, GreaterThan_MLP_V1.1_with_FailuresAnalysis.py.
It adds logic to automatically detect and log failures to a CSV file whenever accuracy
crosses a threshold, creating a valuable dataset artifact for future analysis and the
development of more robust models.
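
As a hedged sketch of how that CSV artifact might be mined afterwards (the pandas usage
below is an assumption for illustration, not part of this script):

    import pandas as pd
    df = pd.read_csv("GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv")
    # Hypothesis to test: most failures are near-ties with equal integer parts.
    near_ties = df[df["num1"].astype(float).astype(int) == df["num2"].astype(float).astype(int)]
    print(f"{len(near_ties)}/{len(df)} failures share the same integer part")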

Here is the output of the first demonstration run in Colab:
Model initialized with 9473 parameters.

--- Starting Training ---
Epoch [1/100], Train Loss: 0.4015, Train Acc: 82.67%, | Val Loss: 0.1690, Val Acc: 97.03%
Epoch [2/100], Train Loss: 0.1743, Train Acc: 92.94%, | Val Loss: 0.0974, Val Acc: 98.04%
Epoch [3/100], Train Loss: 0.1300, Train Acc: 94.54%, | Val Loss: 0.0741, Val Acc: 98.61%
Epoch [4/100], Train Loss: 0.1112, Train Acc: 95.20%, | Val Loss: 0.0618, Val Acc: 98.96%
Epoch [5/100], Train Loss: 0.1019, Train Acc: 95.61%, | Val Loss: 0.0565, Val Acc: 98.79%
Epoch [6/100], Train Loss: 0.0926, Train Acc: 96.04%, | Val Loss: 0.0498, Val Acc: 99.10%
  -> High accuracy detected. Scanning for failures...
    -> Logged 607 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 180 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [7/100], Train Loss: 0.0857, Train Acc: 96.33%, | Val Loss: 0.0456, Val Acc: 99.19%
  -> High accuracy detected. Scanning for failures...
    -> Logged 562 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 161 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [8/100], Train Loss: 0.0827, Train Acc: 96.47%, | Val Loss: 0.0430, Val Acc: 99.14%
  -> High accuracy detected. Scanning for failures...
    -> Logged 538 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 171 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [9/100], Train Loss: 0.0767, Train Acc: 96.73%, | Val Loss: 0.0398, Val Acc: 99.33%
  -> High accuracy detected. Scanning for failures...
    -> Logged 462 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 133 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [10/100], Train Loss: 0.0727, Train Acc: 96.87%, | Val Loss: 0.0376, Val Acc: 99.33%
  -> High accuracy detected. Scanning for failures...
    -> Logged 457 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 134 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [11/100], Train Loss: 0.0692, Train Acc: 97.04%, | Val Loss: 0.0380, Val Acc: 99.06%
  -> High accuracy detected. Scanning for failures...
    -> Logged 703 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 189 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [12/100], Train Loss: 0.0665, Train Acc: 97.17%, | Val Loss: 0.0333, Val Acc: 99.42%
  -> High accuracy detected. Scanning for failures...
    -> Logged 365 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 117 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [13/100], Train Loss: 0.0619, Train Acc: 97.36%, | Val Loss: 0.0316, Val Acc: 99.42%
  -> High accuracy detected. Scanning for failures...
    -> Logged 396 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 115 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [14/100], Train Loss: 0.0599, Train Acc: 97.46%, | Val Loss: 0.0301, Val Acc: 99.41%
  -> High accuracy detected. Scanning for failures...
    -> Logged 397 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 119 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [15/100], Train Loss: 0.0568, Train Acc: 97.63%, | Val Loss: 0.0282, Val Acc: 99.47%
  -> High accuracy detected. Scanning for failures...
    -> Logged 359 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 107 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [16/100], Train Loss: 0.0550, Train Acc: 97.72%, | Val Loss: 0.0266, Val Acc: 99.53%
  -> High accuracy detected. Scanning for failures...
    -> Logged 331 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 94 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [17/100], Train Loss: 0.0524, Train Acc: 97.80%, | Val Loss: 0.0256, Val Acc: 99.55%
  -> High accuracy detected. Scanning for failures...
    -> Logged 321 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 91 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [18/100], Train Loss: 0.0504, Train Acc: 97.93%, | Val Loss: 0.0240, Val Acc: 99.56%
  -> High accuracy detected. Scanning for failures...
    -> Logged 290 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 87 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [19/100], Train Loss: 0.0472, Train Acc: 98.04%, | Val Loss: 0.0228, Val Acc: 99.53%
  -> High accuracy detected. Scanning for failures...
    -> Logged 288 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 93 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [20/100], Train Loss: 0.0447, Train Acc: 98.16%, | Val Loss: 0.0216, Val Acc: 99.61%
  -> High accuracy detected. Scanning for failures...
    -> Logged 289 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [21/100], Train Loss: 0.0445, Train Acc: 98.12%, | Val Loss: 0.0201, Val Acc: 99.69%
  -> High accuracy detected. Scanning for failures...
    -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 63 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [22/100], Train Loss: 0.0412, Train Acc: 98.29%, | Val Loss: 0.0191, Val Acc: 99.65%
  -> High accuracy detected. Scanning for failures...
    -> Logged 227 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [23/100], Train Loss: 0.0395, Train Acc: 98.35%, | Val Loss: 0.0181, Val Acc: 99.65%
  -> High accuracy detected. Scanning for failures...
    -> Logged 236 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [24/100], Train Loss: 0.0373, Train Acc: 98.48%, | Val Loss: 0.0170, Val Acc: 99.71%
  -> High accuracy detected. Scanning for failures...
    -> Logged 209 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 58 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [25/100], Train Loss: 0.0362, Train Acc: 98.53%, | Val Loss: 0.0164, Val Acc: 99.68%
  -> High accuracy detected. Scanning for failures...
    -> Logged 222 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [26/100], Train Loss: 0.0345, Train Acc: 98.61%, | Val Loss: 0.0153, Val Acc: 99.73%
  -> High accuracy detected. Scanning for failures...
    -> Logged 199 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 53 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [27/100], Train Loss: 0.0317, Train Acc: 98.74%, | Val Loss: 0.0149, Val Acc: 99.61%
  -> High accuracy detected. Scanning for failures...
    -> Logged 253 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [28/100], Train Loss: 0.0302, Train Acc: 98.80%, | Val Loss: 0.0134, Val Acc: 99.80%
  -> High accuracy detected. Scanning for failures...
    -> Logged 162 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 40 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [29/100], Train Loss: 0.0299, Train Acc: 98.80%, | Val Loss: 0.0127, Val Acc: 99.77%
  -> High accuracy detected. Scanning for failures...
    -> Logged 163 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 46 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [30/100], Train Loss: 0.0261, Train Acc: 98.98%, | Val Loss: 0.0125, Val Acc: 99.68%
  -> High accuracy detected. Scanning for failures...
    -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [31/100], Train Loss: 0.0251, Train Acc: 99.05%, | Val Loss: 0.0110, Val Acc: 99.84%
  -> High accuracy detected. Scanning for failures...
    -> Logged 135 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 32 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [32/100], Train Loss: 0.0246, Train Acc: 99.01%, | Val Loss: 0.0108, Val Acc: 99.78%
  -> High accuracy detected. Scanning for failures...
    -> Logged 167 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 43 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [33/100], Train Loss: 0.0237, Train Acc: 99.07%, | Val Loss: 0.0103, Val Acc: 99.83%
  -> High accuracy detected. Scanning for failures...
    -> Logged 121 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [34/100], Train Loss: 0.0224, Train Acc: 99.14%, | Val Loss: 0.0096, Val Acc: 99.86%
  -> High accuracy detected. Scanning for failures...
    -> Logged 127 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 29 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [35/100], Train Loss: 0.0220, Train Acc: 99.15%, | Val Loss: 0.0092, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 100 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [36/100], Train Loss: 0.0204, Train Acc: 99.22%, | Val Loss: 0.0090, Val Acc: 99.83%
  -> High accuracy detected. Scanning for failures...
    -> Logged 126 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [37/100], Train Loss: 0.0194, Train Acc: 99.25%, | Val Loss: 0.0083, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 93 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [38/100], Train Loss: 0.0191, Train Acc: 99.25%, | Val Loss: 0.0081, Val Acc: 99.85%
  -> High accuracy detected. Scanning for failures...
    -> Logged 110 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 30 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [39/100], Train Loss: 0.0182, Train Acc: 99.31%, | Val Loss: 0.0076, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 74 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 22 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv

"""
# REFACTORING MISSION:
# This script's objective is to perform automated failure analysis. When
# training or validation accuracy surpasses a 99% threshold, the script
# automatically logs the specific samples the model failed on. These
# failures are appended to a CSV file for later inspection, which is
# invaluable for creating targeted test sets or improving the training data.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import random
import numpy as np
import zipfile
import os

# ==============================================================================
# Part 0: Configuration
# ==============================================================================
class Config:
    # --- Data ---
    num_samples = 100000 # Increased dataset size for more robust training
    train_split = 0.8

    # --- Model Architecture ---
    input_size = 8
    hidden_size_1 = 128
    hidden_size_2 = 64
    output_size = 1

    # --- Training ---
    learning_rate = 1e-4 # Slightly lower LR for finer tuning
    batch_size = 256
    epochs = 100 # Early epochs already clear 99% val accuracy; the demo log above is truncated at epoch 39
    weight_decay = 1e-4

    # --- NEW: Failure Analysis ---
    # The accuracy threshold to trigger logging of failed samples.
    failure_log_threshold = 99.0
    # The name of the script, used for the output CSV file.
    script_name = "GreaterThan_MLP_V1.1_FailureAnalysis"
    failure_log_filename = f"{script_name}_failed_samples.csv"

    # --- Device ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = Config()
print(f"Using device: {config.device}")

# For reproducibility
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)

# ==============================================================================
# Part 1: Colab Utility & Data Generation (Unchanged from V1.0)
# ==============================================================================
def is_in_colab():
    """Checks if the script is running in a Google Colab environment."""
    try:
        import google.colab
        return True
    except ImportError:
        return False

def generate_mlp_data(num_samples):
    """Generates synthetic data for the MLP."""
    print(f"Generating {num_samples} data points...")
    features, labels = [], []
    for _ in range(num_samples):
        a = round(random.uniform(0, 99.99), 2)
        b = round(random.uniform(0, 99.99), 2)
        while a == b:
            b = round(random.uniform(0, 99.99), 2)
        a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
        a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
        features.append(a_digits + b_digits)
        labels.append(0 if a > b else 1)
    X = torch.tensor(features, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
    print("Data generation complete.")
    return X, y
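
# A quick sanity check of the digit encoding (illustrative sketch; not part of
# the main flow -- run it manually if desired):
#   X, y = generate_mlp_data(2)
#   X[0]  # e.g. tensor([1., 0., 0., 0., 0., 9., 2., 1.]) for the pair (10.00, 09.21)
#   y[0]  # tensor([0.]) -> left number is greater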

# ==============================================================================
# Part 2: Model Architecture (Unchanged from V1.0)
# ==============================================================================
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size_1), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_1, hidden_size_2), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_2, output_size)
        )
    def forward(self, x):
        return self.net(x)
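
# Parameter count: (8*128 + 128) + (128*64 + 64) + (64*1 + 1) = 9,473 weights
# and biases, matching the "Model initialized with 9473 parameters" line in
# the demo log above.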

# ==============================================================================
# Part 3: MODIFIED Training and Evaluation Loop
# ==============================================================================

def log_failures(model, loader, split, epoch, filename, device):
    """
    NEW: Iterates through a data loader, finds incorrect predictions,
    and appends them to the specified CSV log file.
    """
    model.eval()
    failures_found = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.round(torch.sigmoid(outputs))
            mismatch_indices = (predicted != labels).squeeze(1)  # squeeze dim 1 only, so a final batch of size 1 keeps its batch dimension
            
            if mismatch_indices.any():
                failed_inputs = inputs[mismatch_indices]
                failed_true_labels = labels[mismatch_indices]
                failed_pred_labels = predicted[mismatch_indices]
                
                with open(filename, 'a') as f:
                    for i in range(failed_inputs.size(0)):
                        # Format the input vector back into a readable string
                        input_vec_int = failed_inputs[i].cpu().numpy().astype(int)
                        num1_str = f"{input_vec_int[0]}{input_vec_int[1]}.{input_vec_int[2]}{input_vec_int[3]}"
                        num2_str = f"{input_vec_int[4]}{input_vec_int[5]}.{input_vec_int[6]}{input_vec_int[7]}"
                        
                        true_label = int(failed_true_labels[i].item())
                        pred_label = int(failed_pred_labels[i].item())
                        
                        f.write(f"{epoch},{split},{num1_str},{num2_str},{true_label},{pred_label}\n")
                        failures_found += 1
    if failures_found > 0:
        print(f"    -> Logged {failures_found} failures for '{split}' split to {filename}")

def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, config):
    """The main training loop, now with failure logging."""
    print("\n--- Starting Training ---")
    
    # NEW: Initialize the failure log file with a header
    with open(config.failure_log_filename, 'w') as f:
        f.write("epoch,split,num1,num2,true_label(0:L>R;1:R>L),predicted_label\n")

    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(config.device), labels.to(config.device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            predicted = torch.round(torch.sigmoid(outputs))
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_avg_loss = train_loss / len(train_loader)
        train_accuracy = 100 * train_correct / train_total
        
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(config.device), labels.to(config.device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = torch.round(torch.sigmoid(outputs))
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_avg_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        
        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Train Loss: {train_avg_loss:.4f}, Train Acc: {train_accuracy:.2f}%, | "
              f"Val Loss: {val_avg_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
              
        # --- NEW: Conditional Failure Logging ---
        if train_accuracy > config.failure_log_threshold or val_accuracy > config.failure_log_threshold:
            print(f"  -> High accuracy detected. Scanning for failures...")
            # Re-iterate over loaders to find and log the specific failures for this epoch
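            # Note: this rescan runs in eval mode with end-of-epoch weights, so
            # the logged train-failure counts are typically far lower than the
            # running train accuracy above (computed with dropout active) implies.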
            log_failures(model, train_loader, 'train', epoch + 1, config.failure_log_filename, config.device)
            log_failures(model, val_loader, 'val', epoch + 1, config.failure_log_filename, config.device)

    print("--- Training Finished ---")

# ==============================================================================
# Part 4: MODIFIED Final Test Suite
# ==============================================================================
def run_final_tests(model, config):
    """Runs the trained model against a suite of hardcoded test cases and logs failures."""
    print("\n--- Running Final Test Suite ---")
    
    test_cases = [
        ("Simple Greater", 10.00, 9.21), ("Simple Lesser", 5.50, 50.50),
        ("Decimal Greater", 54.13, 54.12), ("Decimal Lesser", 99.98, 99.99),
        ("Edge Case: Large Difference", 0.01, 99.99), ("Edge Case: Zero", 0.00, 5.00),
        ("Tricky: Same Integer Part", 25.80, 25.79), ("Tricky: Crossover", 49.99, 50.00),
    ]
    
    results_log = "--- MLP Test Suite Results ---\n\n"
    correct_tests = 0
    
    model.eval()
    with torch.no_grad():
        for description, a, b in test_cases:
            a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
            a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
            feature_vector = torch.tensor(a_digits + b_digits, dtype=torch.float32).to(config.device)
            
            output = model(feature_vector)
            predicted_class = 1 if torch.sigmoid(output).item() > 0.5 else 0
            ground_truth_class = 0 if a > b else 1
            
            result = "CORRECT"
            if predicted_class != ground_truth_class:
                result = "INCORRECT"
                # --- NEW: Log failure to CSV ---
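                # (column reuse: 'final_test' goes in the epoch field and the
                # test description in the split field of the shared CSV schema)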
                with open(config.failure_log_filename, 'a') as f:
                    f.write(f"final_test,{description.replace(',',';')},{a_str},{b_str},{ground_truth_class},{predicted_class}\n")
            else:
                correct_tests += 1
            
            predicted_winner = "Left" if predicted_class == 0 else "Right"
            log_line = (f"Test: '{description}' | {a_str} vs {b_str}\n"
                        f"  -> Model says: {predicted_winner} is greater\n"
                        f"  -> Result: {result}\n" + "-"*30 + "\n")
            print(log_line)
            results_log += log_line

    final_accuracy = 100 * correct_tests / len(test_cases)
    summary = f"\nFinal Test Accuracy: {final_accuracy:.2f}% ({correct_tests}/{len(test_cases)} correct)\n"
    print(summary)
    results_log += summary
    
    return results_log

# ==============================================================================
# Main Execution Block
# ==============================================================================
if __name__ == '__main__':
    X, y = generate_mlp_data(config.num_samples)
    
    dataset = TensorDataset(X, y)
    train_size = int(config.train_split * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    model = SimpleMLP(config.input_size, config.hidden_size_1, config.hidden_size_2, config.output_size).to(config.device)
    print(f"\nModel initialized with {sum(p.numel() for p in model.parameters())} parameters.")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    criterion = nn.BCEWithLogitsLoss()
    
    train_model(model, train_loader, val_loader, optimizer, criterion, config.epochs, config)
    
    test_results = run_final_tests(model, config)
    
    if is_in_colab():
        print(f"\nDetected Google Colab environment. Zipping and downloading results...")
        results_filename = f"{config.script_name}_test_summary.txt"
        with open(results_filename, "w") as f:
            f.write(test_results)
            
        zip_filename = f"{config.script_name}_outputs.zip"
        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            zipf.write(results_filename)
            # Also include the new failure log in the zip file
            if os.path.exists(config.failure_log_filename):
                zipf.write(config.failure_log_filename)
            
        try:
            from google.colab import files
            files.download(zip_filename)
            print(f"Downloaded {zip_filename} successfully.")
        except Exception as e:
            print(f"Could not initiate download. Error: {e}")
    else:
        print(f"\nNot running in Colab. Test results printed above. Failures logged to {config.failure_log_filename}")