sigmoidneuron123 commited on
Commit
082c3a5
·
verified ·
1 Parent(s): 726fcdb

Create AppleAI-converter-colab.py

Browse files
Files changed (1) hide show
  1. AppleAI-converter-colab.py +74 -0
AppleAI-converter-colab.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip install mlx')
3
+ import mlx as mx
4
+ import mlx.nn as mx_nn
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ device = torch.device('cuda')
10
+
11
+ CONFIG = {
12
+ "model_path": "NeoChess/chessy_model.pth",
13
+ "backup_model_path": "NeoChess/chessy_modelt-1.pth",
14
+ }
15
+
16
+ class NN1(nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.embedding = nn.Embedding(13, 64)
20
+ self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
21
+ self.neu = 512
22
+ self.neurons = nn.Sequential(
23
+ nn.Linear(4096, self.neu),
24
+ nn.ReLU(),
25
+ nn.Linear(self.neu, self.neu),
26
+ nn.ReLU(),
27
+ nn.Linear(self.neu, self.neu),
28
+ nn.ReLU(),
29
+ nn.Linear(self.neu, self.neu),
30
+ nn.ReLU(),
31
+ nn.Linear(self.neu, self.neu),
32
+ nn.ReLU(),
33
+ nn.Linear(self.neu, self.neu),
34
+ nn.ReLU(),
35
+ nn.Linear(self.neu, self.neu),
36
+ nn.ReLU(),
37
+ nn.Linear(self.neu, self.neu),
38
+ nn.ReLU(),
39
+ nn.Linear(self.neu, self.neu),
40
+ nn.ReLU(),
41
+ nn.Linear(self.neu, self.neu),
42
+ nn.ReLU(),
43
+ nn.Linear(self.neu, self.neu),
44
+ nn.ReLU(),
45
+ nn.Linear(self.neu, self.neu),
46
+ nn.ReLU(),
47
+ nn.Linear(self.neu, self.neu),
48
+ nn.ReLU(),
49
+ nn.Linear(self.neu, 64),
50
+ nn.ReLU(),
51
+ nn.Linear(64, 4)
52
+ )
53
+
54
+ def forward(self, x):
55
+ x = self.embedding(x)
56
+ x = x.permute(1, 0, 2)
57
+ attn_output, _ = self.attention(x, x, x)
58
+ x = attn_output.permute(1, 0, 2).contiguous()
59
+ x = x.view(x.size(0), -1)
60
+ x = self.neurons(x)
61
+ return x
62
+
63
+ model = NN1().to(device)
64
+ try:
65
+ model.load_state_dict(torch.load(CONFIG['model_path'], map_location=device))
66
+ print(f"Loaded model from {CONFIG['model_path']}")
67
+ except FileNotFoundError:
68
+ try:
69
+ model.load_state_dict(torch.load(CONFIG["backup_model_path"], map_location=device))
70
+ print(f"Loaded backup model from {CONFIG['backup_model_path']}")
71
+ except FileNotFoundError:
72
+ print("No model file found, starting from scratch.")
73
+ weights = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
74
+ np.savez("chessy_model_mlx.npz", **weights)