HaileyStorm committed
Commit • f8519d1
1 Parent(s): 624e9a2
Upload 37 files
Browse files
- chess-gpt-eval/gpt_query.py +265 -0
- chess-gpt-eval/llama_module.py +71 -0
- chess-gpt-eval/main.py +565 -0
- chess-gpt-eval/mamba.py +368 -0
- chess-gpt-eval/mamba/out/meta.pkl +3 -0
- chess-gpt-eval/mamba_lm.py +168 -0
- chess-gpt-eval/mamba_module.py +144 -0
- chess-gpt-eval/nanogpt/__pycache__/model.cpython-310.pyc +0 -0
- chess-gpt-eval/nanogpt/__pycache__/nanogpt_module.cpython-310.pyc +0 -0
- chess-gpt-eval/nanogpt/__pycache__/xformer.cpython-310.pyc +0 -0
- chess-gpt-eval/nanogpt/configurator.py +47 -0
- chess-gpt-eval/nanogpt/model.py +330 -0
- chess-gpt-eval/nanogpt/nanogpt_module.py +148 -0
- chess-gpt-eval/nanogpt/out/meta.pkl +3 -0
- chess-gpt-eval/nanogpt/out/view_ckpt.ipynb +61 -0
- chess-gpt-eval/openings.csv +0 -0
- chess-gpt-eval/pscan.py +226 -0
- chess-gpt-eval/requirements.txt +6 -0
- chess-gpt-eval/xformer.py +330 -0
- chess-mamba-vs-xformer/config/Mamba/11M.py +70 -0
- chess-mamba-vs-xformer/config/Mamba/250M.py +70 -0
- chess-mamba-vs-xformer/config/Mamba/29M.py +70 -0
- chess-mamba-vs-xformer/config/Mamba/50M.py +70 -0
- chess-mamba-vs-xformer/config/Mamba/6.6M.py +70 -0
- chess-mamba-vs-xformer/config/Xformer/11M.py +70 -0
- chess-mamba-vs-xformer/config/Xformer/250M.py +70 -0
- chess-mamba-vs-xformer/config/Xformer/29M.py +70 -0
- chess-mamba-vs-xformer/config/Xformer/50M.py +70 -0
- chess-mamba-vs-xformer/config/Xformer/6.6M.py +70 -0
- chess-mamba-vs-xformer/configurator.py +47 -0
- chess-mamba-vs-xformer/data/anneal/anneal.zip +3 -0
- chess-mamba-vs-xformer/mamba.py +368 -0
- chess-mamba-vs-xformer/mamba_lm.py +168 -0
- chess-mamba-vs-xformer/openings.csv +0 -0
- chess-mamba-vs-xformer/pscan.py +226 -0
- chess-mamba-vs-xformer/train_bygame.py +541 -0
- chess-mamba-vs-xformer/xformer.py +330 -0
chess-gpt-eval/gpt_query.py
ADDED
@@ -0,0 +1,265 @@
import openai
import tiktoken
import json
import os

# import replicate

# for hugging face inference endpoints for codellama
import requests

from typing import Optional

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff

# system message is used in openai_request()
system_message = """Provide the next move in the chess game. Only provide the move, no move numbers."""

# dollars per 1k tokens, per openai.com/pricing
pricing_dict = {
    "gpt-4": 0.03,
    "gpt-4-0301": 0.03,
    "gpt-4-0613": 0.03,
    "gpt-3.5-turbo": 0.0015,
    "gpt-3.5-turbo-0301": 0.0015,
    "gpt-3.5-turbo-0613": 0.0015,
    "gpt-3.5-turbo-16k": 0.003,
    "babbage": 0.0005,
    "gpt-3.5-turbo-instruct": 0.0015,
}

MAX_TOKENS = 10

completion_models = [
    "gpt-3.5-turbo-instruct",
    "babbage",
    "davinci",
]


# tenacity is to handle anytime a request fails
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def get_gpt_response(
    prompt: str, model: str = "gpt-4", temperature: float = 0.0
) -> Optional[str]:
    try:
        messages = []
        # system message is used in openai_request()
        # system_message_dict = {
        #     "role": "system",
        #     "content": system_message,
        # }
        initial_message = {"role": "user", "content": prompt}
        messages.append(initial_message)

        record_messages(messages, model)

        # num_tokens = count_all_tokens(model, messages)
        # prompt_cost = get_prompt_cost(model, num_tokens)
        # print("prompt cost in $:", prompt_cost)

        if model in completion_models:
            response = get_completions_response(model, messages, temperature)
        elif model.startswith("gpt"):
            response = openai_chat_completion_request(model, messages, temperature)
        elif model.startswith("openrouter"):
            response = openrouter_request(model, messages, temperature)
        elif model.startswith("huggingface"):
            response = hugging_face_request(model, messages, temperature)
        elif model.startswith("replicate"):
            response = replicate_request(model, messages, temperature)
        else:
            raise Exception("Invalid model name")

        # response_cost = get_response_cost(model, count_tokens(model, response))
        # print("response cost in $:", response_cost)

        messages.append({"role": "assistant", "content": response})
        record_messages(messages, model)

        return response
    except Exception as e:
        print(f"Error while getting GPT response: {e}")
        return None


def openai_chat_completion_request(
    model: str, messages: list[dict], temperature: float
) -> str:
    system_message_dict = {
        "role": "system",
        "content": system_message,
    }
    messages.append(system_message_dict)
    completion = openai.ChatCompletion.create(
        model=model,
        temperature=temperature,
        messages=messages,
    )
    response = completion.choices[0].message.content
    return response


def openrouter_request(model: str, messages: list[dict], temperature: float) -> str:
    if temperature == 0:
        temperature = 0.001

    with open("gpt_inputs/openrouter_api_key.txt", "r") as f:
        openai.api_key = f.read().strip()

    openai.api_base = "https://openrouter.ai/api/v1"
    OPENROUTER_REFERRER = "https://github.com/adamkarvonen/nanoGPT"

    model = model.replace("openrouter/", "")

    completion = openai.ChatCompletion.create(
        model=model,
        headers={"HTTP-Referer": OPENROUTER_REFERRER},
        messages=messages,
        temperature=temperature,
        max_tokens=MAX_TOKENS,
    )
    response = completion.choices[0].message.content
    return response


def replicate_request(model: str, messages: list[dict], temperature: float) -> str:
    if temperature == 0:
        temperature = 0.001

    with open("gpt_inputs/replicate_api_key.txt", "r") as f:
        api_key = f.read().strip()
    os.environ["REPLICATE_API_TOKEN"] = api_key

    model = model.replace("replicate/", "")

    messages = translate_to_string_input(messages)

    output = replicate.run(
        model,
        input={
            "prompt": messages,
            "max_new_tokens": MAX_TOKENS,
            "temperature": temperature,
        },
    )

    # The meta/llama-2-7b model can stream output as it's running.
    response = ""
    # The predict method returns an iterator, and you can iterate over that output.
    for item in output:
        # https://replicate.com/meta/llama-2-7b/versions/527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef/api#output-schema
        response += item

    return response


def hugging_face_request(model: str, messages: list[dict], temperature: float) -> str:
    def query(payload):
        response = requests.post(API_URL, headers=headers, json=payload)
        return response.json()

    messages = translate_to_string_input(messages)
    API_URL = "https://xxxxxxxx.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Authorization": "Bearer xxxxx",
        "Content-Type": "application/json",
    }

    if temperature == 0:
        temperature = 0.001

    output = query(
        {
            "inputs": messages,
            "parameters": {"temperature": temperature, "max_new_tokens": MAX_TOKENS},
        }
    )

    return output[0]["generated_text"]


def translate_to_string_input(
    openai_messages: list[dict], roles_included: bool = False
):
    # Translate from OpenAI's dict to a single string input
    messages = []
    for message in openai_messages:
        if roles_included:
            messages.append(message["role"] + ": ")
        messages.append(message["content"])
    if roles_included:
        messages.append("assistant: ")
    return "\n".join(messages)


# for gpt-3 models and instruct models
def get_completions_response(
    model: str,
    messages: list[dict] | str,
    temperature: float,
    max_tokens: int = MAX_TOKENS,
) -> str:
    if not isinstance(messages, str):
        prompt = translate_to_string_input(messages, roles_included=False)
    else:
        prompt = messages

    completion = openai.Completion.create(
        model=model, temperature=temperature, prompt=prompt, max_tokens=max_tokens
    )

    response = completion.choices[0].text
    return response


def count_all_tokens(model: str, messages: list[dict[str, str]]) -> int:
    total_tokens = 0
    for message in messages:
        total_tokens += count_tokens(model, message["content"])
    return total_tokens


def count_tokens(model: str, prompt: str) -> int:
    if "gpt" not in model:
        model = "gpt-4"

    encoding = tiktoken.encoding_for_model(model)
    num_tokens = len(encoding.encode(prompt))
    return num_tokens


def get_prompt_cost(model: str, num_tokens: int) -> float:
    # good enough for quick evals
    if model not in pricing_dict:
        return num_tokens * 0.001 * pricing_dict["gpt-4"]
    return num_tokens * 0.001 * pricing_dict[model]


def get_response_cost(model: str, num_tokens: int) -> float:
    # good enough for quick evals
    if model not in pricing_dict:
        return num_tokens * 0.001 * pricing_dict["gpt-4"]

    cost = num_tokens * 0.001 * pricing_dict[model]

    if model == "gpt-4":
        cost *= 2

    return cost


def record_messages(messages: list[dict], model: str):
    # create the conversation in a human-readable format
    conversation_text = ""
    for message in messages:
        conversation_text += message["content"]

    # write the conversation to the next available text file
    with open(f"gpt_outputs/transcript.txt", "w") as f:
        f.write(model + "\n\n")
        f.write(conversation_text)
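For orientation (not part of the upload): a minimal sketch of how gpt_query.get_gpt_response is typically driven. The PGN fragment and the placeholder API key are illustrative; main.py normally reads the key from gpt_inputs/api_key.txt, and this file targets the legacy openai<1.0 SDK (openai.ChatCompletion / openai.Completion).

# Illustrative sketch only; assumes the legacy openai<1.0 SDK that gpt_query.py uses.
import openai
import gpt_query

openai.api_key = "sk-..."  # placeholder; main.py reads gpt_inputs/api_key.txt instead

# A partial PGN transcript; the model is asked to complete the next SAN move.
game_state = "1.e4 e5 2.Nf3 Nc6 3."

# "gpt-3.5-turbo-instruct" is listed in completion_models, so this routes through
# get_completions_response() with MAX_TOKENS=10, retried via tenacity on failure.
move = gpt_query.get_gpt_response(game_state, model="gpt-3.5-turbo-instruct", temperature=0.0)
print(move)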
chess-gpt-eval/llama_module.py
ADDED
@@ -0,0 +1,71 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

from typing import Optional


# There are a couple non optimal parts of this code:
# 1. It doesn't inherit the Player class in main.py, which throws type checking errors
# 2. get_move_from_response() is duplicated from main.py
# However, I didn't want to add clutter and major dependencies like torch, peft, and transformers
# to those not using this class. So, this was my compromise.
class BaseLlamaPlayer:
    def __init__(
        self, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, model_name: str
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.model_name = model_name

    def get_llama_response(self, game_state: str, temperature: float) -> Optional[str]:
        prompt = game_state
        tokenized_input = self.tokenizer(prompt, return_tensors="pt").to("cuda")
        result = self.model.generate(
            **tokenized_input, max_new_tokens=10, temperature=temperature
        ).to("cpu")
        input_ids_tensor = tokenized_input["input_ids"]
        # transformers generate() returns <s> + prompt + output. This grabs only the output
        res_sliced = result[:, input_ids_tensor.shape[1] :]
        return self.tokenizer.batch_decode(res_sliced)[0]

    def get_move_from_response(self, response: Optional[str]) -> Optional[str]:
        if response is None:
            return None

        # Parse the response to get only the first move
        moves = response.split()
        first_move = moves[0] if moves else None

        return first_move

    def get_move(
        self, board: str, game_state: str, temperature: float
    ) -> Optional[str]:
        completion = self.get_llama_response(game_state, temperature)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        return {"model": self.model_name}


class LocalLlamaPlayer(BaseLlamaPlayer):
    def __init__(self, model_name: str):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map=0
        ).to("cuda")
        super().__init__(tokenizer, model, model_name)


class LocalLoraLlamaPlayer(BaseLlamaPlayer):
    def __init__(self, base_model_id: str, adapter_model_path: str):
        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
        model = (
            PeftModel.from_pretrained(base_model, adapter_model_path)
            .merge_and_unload()
            .to("cuda")
        )

        super().__init__(tokenizer, model, adapter_model_path)
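A hedged usage sketch (not in the commit): the Llama players expose the same get_move/get_config surface that main.py expects from a Player. The model name below is the one referenced in main.py's commented-out examples, and a CUDA GPU plus torch/transformers/peft are assumed.

# Illustrative sketch only.
import chess
from llama_module import LocalLlamaPlayer

player = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")

board = chess.Board()      # unused by this player, but part of the shared interface
game_state = "1.e4 e5 2."
move_san = player.get_move(board, game_state, temperature=0.2)
print(move_san, player.get_config())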
chess-gpt-eval/main.py
ADDED
@@ -0,0 +1,565 @@
import openai
import chess
import chess.engine
import os
import csv
import random
import time
import platform

# NOTE: LLAMA AND NANOGPT ARE EXPERIMENTAL PLAYERS, if not using them, comment them out
# from llama_module import BaseLlamaPlayer, LocalLlamaPlayer, LocalLoraLlamaPlayer
from nanogpt.nanogpt_module import NanoGptPlayer
from mamba_module import MambaPlayer
import gpt_query
from lczero.backends import Weights, Backend, GameState
import numpy as np

from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class LegalMoveResponse:
    move_san: Optional[str] = None
    move_uci: Optional[chess.Move] = None
    attempts: int = 0
    is_resignation: bool = False
    is_illegal_move: bool = False


# Define base Player class
class Player:
    def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
        raise NotImplementedError

    def get_config(self) -> dict:
        raise NotImplementedError


class GPTPlayer(Player):
    def __init__(self, model: str):
        with open("gpt_inputs/api_key.txt", "r") as f:
            openai.api_key = f.read().strip()
        self.model = model

    def get_move(
        self, board: chess.Board, game_state: str, temperature: float
    ) -> Optional[str]:
        response = get_gpt_response(game_state, self.model, temperature)
        return get_move_from_gpt_response(response)

    def get_config(self) -> dict:
        return {"model": self.model}


class LC0PLayer(Player):
    # "11258-32x4-se.pb.gz" = stockfish level 0- = skill 0
    # "11258-48x5-se.pb.gz" = stockfish level 0+ = skill 1
    # "11258-80x7-se.pb.gz" = stockfish level 1 = skill 2
    # "11258-104x9-se.pb.gz" = stockfish level 2 = skill 3
    # "TK-6430 aka 128x10-BPR-64M-6430000.pb.gz" = stockfish level 3 = skill 4
    # "00af53b081e80147172e6f281c01daf5ca19ada173321438914c730370aa4267" = stockfish level 4 = skill 5
    # "b2ec465d0fb5b5eb39d2e1e3f74041a5d2fc92d413b71aa7ea0b6fb082ccba9c" = stockfish level 5+ = skill 6
    def __init__(self, skill):
        self.skill = skill
        network_paths = [
            "./lc0/build/release/11258-32x4-se.pb.gz",
            "./lc0/build/release/11258-48x5-se.pb.gz",
            "./lc0/build/release/11258-80x7-se.pb.gz",
            "./lc0/build/release/11258-104x9-se.pb.gz",
            "./lc0/build/release/TK-6430 aka 128x10-BPR-64M-6430000.pb.gz",
            "./lc0/build/release/00af53b081e80147172e6f281c01daf5ca19ada173321438914c730370aa4267",
            "./lc0/build/release/b2ec465d0fb5b5eb39d2e1e3f74041a5d2fc92d413b71aa7ea0b6fb082ccba9c",
        ]
        print(f"\n\nLoading lc0 network: {network_paths[skill]}\n\n")
        self.weights = Weights(network_paths[skill])
        self.backend = Backend(weights=self.weights)
        self.gamestate = GameState()

    def get_move(self, board: chess.Board, game_state: str, temperature: float):
        self.gamestate = GameState(fen=board.fen())
        input_planes = self.gamestate.as_input(self.backend)
        result = self.backend.evaluate(input_planes)[0]
        moves = self.gamestate.moves()
        policy_indices = self.gamestate.policy_indices()
        move_probs = np.array(result.p_softmax(*policy_indices))
        best_move_idx = move_probs.argmax()
        best_move = moves[best_move_idx]
        return board.san(chess.Move.from_uci(best_move))

    def get_config(self) -> dict:
        return {"network": self.weights, "skill_level": self.skill, "play_time": 0}


class StockfishPlayer(Player):

    @staticmethod
    def get_stockfish_path() -> str:
        """
        Determines the operating system and returns the appropriate path for Stockfish.

        Returns:
            str: Path to the Stockfish executable based on the operating system.
        """
        if platform.system() == 'Linux':
            return "/usr/games/stockfish"
        elif platform.system() == 'Darwin':  # Darwin is the system name for macOS
            return "stockfish"
        elif platform.system() == 'Windows':
            return r"C:\Users\Haile\Downloads\stockfish\stockfish-windows-x86-64-avx2.exe"
        else:
            raise OSError("Unsupported operating system")

    def __init__(self, skill_level: int, play_time: float):
        self._skill_level = skill_level
        self._play_time = play_time
        # If getting started, you need to run brew install stockfish
        stockfish_path = StockfishPlayer.get_stockfish_path()
        self._engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)

    def get_move(
        self, board: chess.Board, game_state: str, temperature: float
    ) -> Optional[str]:
        if self._skill_level == -2:
            legal_moves = list(board.legal_moves)
            random_move = random.choice(legal_moves)
            return board.san(random_move)
        elif self._skill_level < 0:
            self._engine.configure({"Skill Level": 0})
            result = self._engine.play(
                board, chess.engine.Limit(time=1e-8, depth=1, nodes=1)
            )

        else:
            self._engine.configure({"Skill Level": self._skill_level})
            result = self._engine.play(board, chess.engine.Limit(time=self._play_time))
        if result.move is None:
            return None
        return board.san(result.move)

    def get_config(self) -> dict:
        return {"skill_level": self._skill_level, "play_time": self._play_time}

    def close(self):
        self._engine.quit()


class HumanPlayer(Player):
    def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
        # Print board for human player
        print(board)
        while True:
            move = input("Enter your move (SAN format): ")
            try:
                move_uci = board.parse_san(move)
                if move_uci in board.legal_moves:
                    return move
            except:
                print("Illegal move, try again.")

    def get_config(self) -> dict:
        return {"player": "human"}


def get_gpt_response(game_state: str, model: str, temperature: float) -> Optional[str]:
    # trying to prevent what I believe to be rate limit issues
    if model == "gpt-4":
        time.sleep(0.4)
    response = gpt_query.get_gpt_response(game_state, model, temperature)
    return response


def get_move_from_gpt_response(response: Optional[str]) -> Optional[str]:
    if response is None:
        return None

    # Parse the response to get only the first move
    moves = response.split()
    first_move = moves[0] if moves else None

    return first_move


def record_results(
    board: chess.Board,
    player_one: Player,
    player_two: Player,
    game_state: str,
    player_one_illegal_moves: int,
    player_two_illegal_moves: int,
    player_one_legal_moves: int,
    player_two_legal_moves: int,
    total_time: float,
    player_one_resignation: bool,
    player_two_resignation: bool,
    player_one_failed_to_find_legal_move: bool,
    player_two_failed_to_find_legal_move: bool,
    total_moves: int,
    illegal_moves: int,
):
    unique_game_id = generate_unique_game_id()

    (
        player_one_title,
        player_two_title,
        player_one_time,
        player_two_time,
    ) = get_player_titles_and_time(player_one, player_two)

    if player_one_resignation or player_one_failed_to_find_legal_move:
        result = "0-1"
        player_one_score = 0
        player_two_score = 1
    elif player_two_resignation or player_two_failed_to_find_legal_move:
        result = "1-0"
        player_one_score = 1
        player_two_score = 0
    else:
        result = board.result()
        # Hmmm.... debating this one. Annoying if I leave it running and it fails here for some reason, probably involving some
        # resignation / failed move situation I didn't think of
        # -1e10 at least ensures it doesn't fail silently
        if "-" in result:
            player_one_score = result.split("-")[0]
            player_two_score = result.split("-")[1]
        elif result == "*":  # Draw due to hitting max moves
            player_one_score = 0  # 1/2
            player_two_score = 1  # 1/2
        else:
            player_one_score = -1e10
            player_two_score = -1e10

    info_dict = {
        "game_id": unique_game_id,
        "transcript": game_state,
        "result": result,
        "player_one": player_one_title,
        "player_two": player_two_title,
        "player_one_time": player_one_time,
        "player_two_time": player_two_time,
        "player_one_score": player_one_score,
        "player_two_score": player_two_score,
        "player_one_illegal_moves": player_one_illegal_moves,
        "player_two_illegal_moves": player_two_illegal_moves,
        "player_one_legal_moves": player_one_legal_moves,
        "player_two_legal_moves": player_two_legal_moves,
        "player_one_resignation": player_one_resignation,
        "player_two_resignation": player_two_resignation,
        "player_one_failed_to_find_legal_move": player_one_failed_to_find_legal_move,
        "player_two_failed_to_find_legal_move": player_two_failed_to_find_legal_move,
        "game_title": f"{player_one_title} vs. {player_two_title}",
        "number_of_moves": board.fullmove_number,
        "time_taken": total_time,
        "total_moves": total_moves,
        "illegal_moves": illegal_moves,
    }

    if RUN_FOR_ANALYSIS:
        csv_file_path = f"logs/{player_one_recording_name}_vs_{player_two_recording_name}"
        csv_file_path = csv_file_path.replace(".", "_")  # Because I'm using ckpt filenames for nanogpt models
        csv_file_path += ".csv"
    else:
        csv_file_path = recording_file

    # Determine if we need to write headers (in case the file doesn't exist yet)
    write_headers = not os.path.exists(csv_file_path)

    # Append the results to the CSV file
    with open(csv_file_path, "a", newline="") as csv_file:  # THIS WAS APPEND
        writer = csv.DictWriter(csv_file, fieldnames=info_dict.keys())
        if write_headers:
            writer.writeheader()
        writer.writerow(info_dict)

    with open("game.txt", "w") as f:
        f.write(game_state)


def generate_unique_game_id() -> str:
    timestamp = int(time.time())
    random_num = random.randint(1000, 9999)  # 4-digit random number
    return f"{timestamp}-{random_num}"


def get_player_titles_and_time(
    player_one: Player, player_two: Player
) -> Tuple[str, str, Optional[float], Optional[float]]:
    player_one_config = player_one.get_config()
    player_two_config = player_two.get_config()

    # For player one
    if "model" in player_one_config:
        player_one_title = player_one_config["model"]
        player_one_time = None
    else:
        player_one_title = f"Stockfish {player_one_config['skill_level']}"
        player_one_time = player_one_config["play_time"]

    # For player two
    if "model" in player_two_config:
        player_two_title = player_two_config["model"]
        player_two_time = None
    else:
        player_two_title = f"Stockfish {player_two_config['skill_level']}"
        player_two_time = player_two_config["play_time"]

    return (player_one_title, player_two_title, player_one_time, player_two_time)


used_openings = []
def initialize_game_with_opening(
    game_state: str, board: chess.Board
) -> Tuple[str, chess.Board]:
    global used_openings
    with open("openings.csv", "r") as file:
        lines = file.readlines()[1:]  # Skip header
    moves_string = random.choice(lines)
    while moves_string in used_openings:
        moves_string = random.choice(lines)
    used_openings.append(moves_string)
    if move_num_in_gamestate:
        game_state = moves_string.rstrip() + " "
    else:
        game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in moves_string.split()])
        game_state = game_state.rstrip() + " "
    # Splitting the moves string on spaces
    tokens = moves_string.split()

    for token in tokens:
        # If the token contains a period, it's a move number + move combination
        if "." in token:
            move = token.split(".")[-1]  # Take the move part after the period
        else:
            move = token

        board.push_san(move)
    return game_state.rstrip(), board


# Return is (move_san, move_uci, attempts, is_resignation, is_illegal_move)
def get_legal_move(
    player: Player,
    board: chess.Board,
    game_state: str,
    player_one: bool,
    max_attempts: int = 5,
) -> LegalMoveResponse:
    """Request a move from the player and ensure it's legal."""
    move_san = None
    move_uci = None

    for attempt in range(max_attempts):
        #print(f"get_legal_move: |{game_state}|")
        move_san = player.get_move(
            board, game_state, min(((attempt / max_attempts) * 1) + 0.001, 0.75)
        )

        # Sometimes when GPT thinks it's the end of the game, it will just output the result
        # Like "1-0". If so, this really isn't an illegal move, so we'll add a check for that.
        if move_san is not None:
            if move_san == "1-0" or move_san == "0-1" or move_san == "1/2-1/2":
                print(f"{move_san}, player has resigned")
                return LegalMoveResponse(
                    move_san=None,
                    move_uci=None,
                    attempts=attempt,
                    is_resignation=True,
                )

        try:
            move_uci = board.parse_san(move_san)
        except Exception as e:
            print(f"Error parsing move {move_san}: {e}")
            # check if player is gpt-3.5-turbo-instruct
            # only recording errors for gpt-3.5-turbo-instruct because its errors are so rare
            if player.get_config()["model"] == "gpt-3.5-turbo-instruct":
                with open("gpt-3.5-turbo-instruct-illegal-moves.txt", "a") as f:
                    f.write(f"{game_state}\n{move_san}\n")
            continue

        if move_uci in board.legal_moves:
            if player_one == False:
                if not move_san.startswith(" "):
                    move_san = " " + move_san
            else:
                if move_san.startswith(" "):
                    move_san = move_san[1:]
            return LegalMoveResponse(move_san, move_uci, attempt)
        print(f"Illegal move: {move_san}")

    # If we reach here, the player has made illegal moves for all attempts.
    print(f"{player} provided illegal moves for {max_attempts} attempts.")
    return LegalMoveResponse(
        move_san=None, move_uci=None, attempts=max_attempts, is_illegal_move=True
    )


def play_turn(
    player: Player, board: chess.Board, game_state: str, player_one: bool
) -> Tuple[str, bool, bool, int]:
    result = get_legal_move(player, board, game_state, player_one, 5)
    illegal_moves = result.attempts
    move_san = result.move_san
    move_uci = result.move_uci
    resignation = result.is_resignation
    failed_to_find_legal_move = result.is_illegal_move

    if resignation:
        print(f"{player} resigned with result: {board.result()}")
    elif failed_to_find_legal_move:
        print(f"Game over: 5 consecutive illegal moves from {player}")
    elif move_san is None or move_uci is None:
        print(f"Game over: {player} failed to find a legal move")
    else:
        board.push(move_uci)
        game_state += move_san
        print(move_san, end=" ")

    return game_state, resignation, failed_to_find_legal_move, illegal_moves


def play_game(
    player_one: Player,
    player_two: Player,
    max_games: int = 10,
    random_opening_seed: bool = False,
):
    for z in range(max_games):
        print(f"\nGame {z} of {max_games}\n")

        with open("gpt_inputs/prompt.txt", "r") as f:
            game_state = f.read()
        board = chess.Board()

        if random_opening_seed:
            game_state, board = initialize_game_with_opening(game_state, board)
        #print(f"play_gamea after init: |{game_state}|")
        player_one_illegal_moves = 0
        player_two_illegal_moves = 0
        player_one_legal_moves = 0
        player_two_legal_moves = 0
        player_one_resignation = False
        player_two_resignation = False
        player_one_failed_to_find_legal_move = False
        player_two_failed_to_find_legal_move = False
        start_time = time.time()

        total_moves = 0
        illegal_moves = 0
        print_for_human = isinstance(player_one, HumanPlayer) or isinstance(player_two, HumanPlayer)

        while not board.is_game_over():
            if print_for_human:
                print(board)

            with open("game.txt", "w") as f:
                f.write(game_state)
            current_move_num = f"{board.fullmove_number if move_num_in_gamestate else ''}."
            total_moves += 1
            # I increment legal moves here so player_two isn't penalized for the game ending before its turn
            player_one_legal_moves += 1
            player_two_legal_moves += 1

            # this if statement may be overkill, just trying to get format to exactly match PGN notation
            if board.fullmove_number != 1:
                game_state += " "
            game_state += current_move_num
            #print(f"|{game_state}|")
            #print(f"{current_move_num}", end=" ")

            (
                game_state,
                player_one_resignation,
                player_one_failed_to_find_legal_move,
                illegal_moves_one,
            ) = play_turn(player_one, board, game_state, player_one=True)
            player_one_illegal_moves += illegal_moves_one
            if illegal_moves_one != 0:
                player_one_legal_moves -= 1
            if (
                board.is_game_over()
                or player_one_resignation
                or player_one_failed_to_find_legal_move
            ):
                break

            (
                game_state,
                player_two_resignation,
                player_two_failed_to_find_legal_move,
                illegal_moves_two,
            ) = play_turn(player_two, board, game_state, player_one=False)
            player_two_illegal_moves += illegal_moves_two
            if illegal_moves_two != 0:
                player_two_legal_moves -= 1
            if (
                board.is_game_over()
                or player_two_resignation
                or player_two_failed_to_find_legal_move
            ):
                break

            print("\n", end="")

            if total_moves > MAX_MOVES:
                break

        end_time = time.time()
        total_time = end_time - start_time
        print(f"\nGame over. Total time: {total_time} seconds")
        print(f"Result: {board.result()}")
        print(board)
        print()
        record_results(
            board,
            player_one,
            player_two,
            game_state,
            player_one_illegal_moves,
            player_two_illegal_moves,
            player_one_legal_moves,
            player_two_legal_moves,
            total_time,
            player_one_resignation,
            player_two_resignation,
            player_one_failed_to_find_legal_move,
            player_two_failed_to_find_legal_move,
            total_moves,
            illegal_moves,
        )
    if isinstance(player_one, StockfishPlayer):
        player_one.close()
    if isinstance(player_two, StockfishPlayer):
        player_two.close()

    # print(game_state)


RUN_FOR_ANALYSIS = True
MAX_MOVES = 999  # Due to nanogpt max input length of 1024
recording_file = "logs/determine.csv"  # default recording file. Because we are using list [player_ones], recording_file is overwritten
# player_one_recording_name = "ckpt_8.pt"
#player_ones = ["ckpt_iter_20000.pt","ckpt_iter_40000.pt","ckpt_iter_60000.pt","ckpt_iter_80000.pt"] #["ckpt.pt"]
player_ones = ["Xformer/6.6M/ckpt.pt"]
player_two_recording_name = "lc0_sweep" #"stockfish_sweep"
move_num_in_gamestate = False
if __name__ == "__main__":
    for nanogpt_player in player_ones:
        player_one_recording_name = nanogpt_player
        for i in range(2): #range(11):
            num_games = 265 #265 instead of 250 for duplicates (for lc0, stockfish doesn't need it)
            # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
            # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
            # player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")
            # player_one = GPTPlayer(model="gpt-4")
            # player_one = StockfishPlayer(skill_level=-1, play_time=0.1)

            player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=False)
            # player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=False)
            #player_two = StockfishPlayer(skill_level=i, play_time=0.1)
            player_two = LC0PLayer(skill=i)

            # player_two = GPTPlayer(model="gpt-4")
            # player_two = GPTPlayer(model="gpt-3.5-turbo-instruct")

            print(f"\n\nSTARTING GAMES AGAINST STOCKFISH LEVEL {i}\n\n")
            #print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")

            play_game(player_one, player_two, num_games, random_opening_seed=True)

    print("\n\n\n********\nDONE!\n********\n\n\n")
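A hedged sketch (not in the commit) of driving play_game outside the lc0 sweep above, e.g. one checkpoint against a single Stockfish level. record_results() reads the module-level recording names when RUN_FOR_ANALYSIS is True, so they are set explicitly here; the checkpoint path is the one already listed in player_ones.

# Illustrative sketch only: same wiring as the __main__ block, single Stockfish opponent.
import main

main.player_one_recording_name = "Xformer/6.6M/ckpt.pt"
main.player_two_recording_name = "stockfish_0"

player_one = main.NanoGptPlayer(model_name="Xformer/6.6M/ckpt.pt", move_num_in_gamestate=False)
player_two = main.StockfishPlayer(skill_level=0, play_time=0.1)

# 10 games, each seeded with a random opening from openings.csv; results are appended
# to a CSV under logs/ derived from the two recording names.
main.play_game(player_one, player_two, 10, random_opening_seed=True)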
chess-gpt-eval/mamba.py
ADDED
@@ -0,0 +1,368 @@
import math
from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from pscan import pscan

"""

This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
The major differences are :
-the convolution is done with torch.nn.Conv1d
-the selective scan is done in PyTorch

A sequential version of the selective scan is also available for comparison.

- A Mamba model is composed of several layers, which are ResidualBlock.
- A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
- This leaves us with the MambaBlock : its input x is (B, L, D) and its outputs y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
We then multiply it by silu(z).
See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.

"""

@dataclass
class MambaConfig:
    d_model: int # D
    n_layers: int
    dt_rank: Union[int, str] = 'auto'
    d_state: int = 16 # N in paper/comments
    expand_factor: int = 2 # E in paper/comments
    d_conv: int = 4

    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random" # "random" or "constant"
    dt_scale: float = 1.0
    dt_init_floor = 1e-4

    bias: bool = False
    conv_bias: bool = True

    pscan: bool = True # use parallel scan mode or sequential mode when training

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

class Mamba(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
        #self.norm_f = RMSNorm(config.d_model)

    def forward(self, x):
        # x : (B, L, D)

        # y : (B, L, D)

        for layer in self.layers:
            x = layer(x)

        #x = self.norm_f(x)

        return x

    def step(self, x, caches):
        # x : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        # y : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer.step(x, caches[i])

        return x, caches

class ResidualBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(config.d_model)

    def forward(self, x):
        # x : (B, L, D)

        # output : (B, L, D)

        output = self.mixer(self.norm(x)) + x
        return output

    def step(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs: (B, ED, d_conv-1)

        # output : (B, D)
        # cache : (h, inputs)

        output, cache = self.mixer.step(self.norm(x), cache)
        output = output + x
        return output, cache

class MambaBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
                                kernel_size=config.d_conv, bias=config.conv_bias,
                                groups=config.d_inner,
                                padding=config.d_conv - 1)

        nn.init.kaiming_normal_(self.conv1d.weight, mode='fan_out', nonlinearity='leaky_relu')

        # projects x to input-dependent Δ, B, C
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)

        # projects Δ from dt_rank to d_inner
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # dt initialization
        # dt weights
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # dt bias
        dt = torch.exp(
            torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        #self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        # todo : explain why removed

        # S4D real initialization
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
        self.D = nn.Parameter(torch.ones(config.d_inner))

        # projects block output from ED back to D
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        # x : (B, L, D)

        # y : (B, L, D)

        _, L, _ = x.shape

        xz = self.in_proj(x) # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)

        # x branch
        x = x.transpose(1, 2) # (B, ED, L)
        x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
        x = x.transpose(1, 2) # (B, L, ED)

        x = F.silu(x)
        y = self.ssm(x)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, L, D)

        return output

    def ssm(self, x):
        # x : (B, L, ED)

        # y : (B, L, ED)

        A = -torch.exp(self.A_log.float()) # (ED, N)
        D = self.D.float()
        # TODO remove .float()

        deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)

        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)

        return y

    def selective_scan(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        hs = pscan(deltaA, BX)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        _, L, _ = x.shape

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
        hs = []

        for t in range(0, L):
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1) # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    # -------------------------- inference -------------------------- #
    """
    Concerning auto-regressive inference

    The cool part of using Mamba : inference is constant w.r.t. sequence length
    We just have to keep in cache, for each layer, two things :
    - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
    - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
      (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
      (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)

    Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
    h is (B, ED, N), and inputs is (B, ED, d_conv-1)
    The MambaBlock.step() receives this cache, and, along with outputting the output, also outputs the updated cache for the next call.

    The cache object is initialized as follows : (None, torch.zeros()).
    When h is None, the selective scan function detects it and starts with h=0.
    The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)

    As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
    """

    def step(self, x, cache):
        # x : (B, D)
        # cache : (h, inputs)
        # h : (B, ED, N)
        # inputs : (B, ED, d_conv-1)

        # y : (B, D)
        # cache : (h, inputs)

        h, inputs = cache

        xz = self.in_proj(x) # (B, 2*ED)
        x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)

        # x branch
        x_cache = x.unsqueeze(2)
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)

        x = F.silu(x)
        y, h = self.ssm_step(x, h)

        # z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, D)

        # prepare cache for next call
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
        cache = (h, inputs)

        return output, cache

    def ssm_step(self, x, h):
        # x : (B, ED)
        # h : (B, ED, N)

        # y : (B, ED)
        # h : (B, ED, N)

        A = -torch.exp(self.A_log.float()) # (ED, N) # todo : don't recompute this every time, since it's independent of the timestep
        D = self.D.float()
        # TODO remove .float()

        deltaBC = self.x_proj(x) # (B, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)

        if h is None:
            h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)

        h = deltaA * h + BX # (B, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        # todo : why h.squeeze(1) ??
        return y, h.squeeze(1)

# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
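A hedged sketch (not in the commit) of the cache-based autoregressive path described in the inference docstring above: each layer carries a (h, inputs) tuple, with h starting as None and inputs as zeros of shape (B, ED, d_conv-1), matching how mamba_lm.py initializes its caches. The tiny config values are arbitrary.

# Illustrative sketch only: single-timestep inference with Mamba.step and per-layer caches.
import torch
from mamba import Mamba, MambaConfig

config = MambaConfig(d_model=64, n_layers=2)
model = Mamba(config)

B = 1  # batch size
caches = [
    (None, torch.zeros(B, config.d_inner, config.d_conv - 1))
    for _ in range(config.n_layers)
]

x = torch.randn(B, config.d_model)  # one timestep, shape (B, D)
for _ in range(5):
    y, caches = model.step(x, caches)  # y : (B, D); each layer's cache tuple is updated
    x = y  # stand-in for embedding the next sampled token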
chess-gpt-eval/mamba/out/meta.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1121191e401988851de5744fe27fe463c3a086fc8c9a5538ef7fc12162bfb09
size 373
chess-gpt-eval/mamba_lm.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
from dataclasses import dataclass, fields, asdict
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from mamba import Mamba, MambaConfig, RMSNorm
|
9 |
+
|
10 |
+
"""
|
11 |
+
|
12 |
+
Encapsulates a Mamba model as a language model. It has an embedding layer and an LM head which maps the model output to logits.
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
# TODO generate function : batch size != 1 ? (for now B=1)
|
17 |
+
# TODO generate function : top-p sampling
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class MambaLMConfig(MambaConfig):
|
21 |
+
vocab_size: int = 32000
|
22 |
+
pad_vocab_size_multiple: int = 8
|
23 |
+
|
24 |
+
def __post_init__(self):
|
25 |
+
super().__post_init__()
|
26 |
+
|
27 |
+
#if self.vocab_size % self.pad_vocab_size_multiple != 0:
|
28 |
+
# self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
|
29 |
+
|
30 |
+
def to_mamba_config(self) -> MambaConfig:
|
31 |
+
mamba_config_fields = {field.name for field in fields(MambaConfig)}
|
32 |
+
filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
|
33 |
+
return MambaConfig(**filtered_dict)
|
34 |
+
|
35 |
+
# adapted from https://github.com/johnma2006/mamba-minimal
|
36 |
+
def from_pretrained(name: str):
|
37 |
+
"""
|
38 |
+
Returns a model loaded with pretrained weights pulled from HuggingFace.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
name: As of now, supports
|
42 |
+
* 'state-spaces/mamba-2.8b-slimpj'
|
43 |
+
* 'state-spaces/mamba-2.8b'
|
44 |
+
* 'state-spaces/mamba-1.4b'
|
45 |
+
* 'state-spaces/mamba-790m'
|
46 |
+
* 'state-spaces/mamba-370m'
|
47 |
+
* 'state-spaces/mamba-130m'
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
model: a Mamba model configured with the proper parameters and initialized with the proper weights
|
51 |
+
"""
|
52 |
+
|
53 |
+
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
54 |
+
from transformers.utils.hub import cached_file
|
55 |
+
|
56 |
+
def load_config_hf(model_name):
|
57 |
+
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
58 |
+
return json.load(open(resolved_archive_file))
|
59 |
+
|
60 |
+
def load_state_dict_hf(model_name):
|
61 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
62 |
+
return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
|
63 |
+
|
64 |
+
# copy config data
|
65 |
+
config_data = load_config_hf(name)
|
66 |
+
config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
|
67 |
+
|
68 |
+
model = MambaLM(config)
|
69 |
+
|
70 |
+
# copy weights
|
71 |
+
state_dict = load_state_dict_hf(name)
|
72 |
+
|
73 |
+
new_state_dict = {}
|
74 |
+
for key in state_dict:
|
75 |
+
if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
|
76 |
+
new_key = key.replace('backbone.', '')
|
77 |
+
else:
|
78 |
+
new_key = key.replace('backbone', 'mamba')
|
79 |
+
|
80 |
+
new_state_dict[new_key] = state_dict[key]
|
81 |
+
|
82 |
+
model.load_state_dict(new_state_dict)
|
83 |
+
|
84 |
+
return model
|
85 |
+
|
86 |
+
class MambaLM(nn.Module):
|
87 |
+
def __init__(self, lm_config: MambaLMConfig):
|
88 |
+
super().__init__()
|
89 |
+
self.lm_config = lm_config
|
90 |
+
self.config = lm_config.to_mamba_config()
|
91 |
+
|
92 |
+
self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
|
93 |
+
self.mamba = Mamba(self.config)
|
94 |
+
self.norm_f = RMSNorm(self.config.d_model)
|
95 |
+
|
96 |
+
self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
|
97 |
+
self.lm_head.weight = self.embedding.weight
|
98 |
+
|
99 |
+
def forward(self, tokens):
|
100 |
+
# tokens : (B, L)
|
101 |
+
|
102 |
+
# logits : (B, L, vocab_size)
|
103 |
+
|
104 |
+
x = self.embedding(tokens)
|
105 |
+
|
106 |
+
x = self.mamba(x)
|
107 |
+
x = self.norm_f(x)
|
108 |
+
|
109 |
+
logits = self.lm_head(x)
|
110 |
+
|
111 |
+
return logits
|
112 |
+
|
113 |
+
def step(self, token, caches):
|
114 |
+
# token : (B)
|
115 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
116 |
+
|
117 |
+
# logits : (B, vocab_size)
|
118 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
119 |
+
|
120 |
+
x = self.embedding(token)
|
121 |
+
|
122 |
+
x, caches = self.mamba.step(x, caches)
|
123 |
+
x = self.norm_f(x)
|
124 |
+
|
125 |
+
logits = self.lm_head(x)
|
126 |
+
|
127 |
+
return logits, caches
|
128 |
+
|
129 |
+
# TODO temperature
|
130 |
+
# TODO process prompt in parallel, and pass in sequential mode when prompt is finished ?
|
131 |
+
def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
|
132 |
+
self.eval()
|
133 |
+
|
134 |
+
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens)
|
135 |
+
|
136 |
+
# caches is a list of cache, one per layer
|
137 |
+
# cache is composed of : the hidden state, and the last d_conv-1 inputs
|
138 |
+
# the hidden state because the update is like an RNN
|
139 |
+
# the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
|
140 |
+
caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
|
141 |
+
|
142 |
+
for i in range(input_ids.size(1) + num_tokens - 1):
|
143 |
+
with torch.no_grad():
|
144 |
+
# forward the new output, get new cache
|
145 |
+
next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
|
146 |
+
|
147 |
+
# sample (no sampling when the prompt is being processed)
|
148 |
+
if i+1 >= input_ids.size(1):
|
149 |
+
probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)
|
150 |
+
|
151 |
+
if top_k is not None:
|
152 |
+
values, _ = torch.topk(probs, k=top_k) # (1, k), sorted from largest to smallest; values[:, -1] is the k-th largest
|
153 |
+
probs[probs < values[:, -1, None]] = 0
|
154 |
+
probs = probs / probs.sum(axis=1, keepdims=True)
|
155 |
+
|
156 |
+
if sample:
|
157 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
|
158 |
+
else:
|
159 |
+
next_token = torch.argmax(probs, dim=-1) # (1)
|
160 |
+
|
161 |
+
input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
|
162 |
+
|
163 |
+
output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
|
164 |
+
|
165 |
+
self.train()
|
166 |
+
|
167 |
+
return output
|
168 |
+
|
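A hedged usage sketch for from_pretrained() and MambaLM.generate() above. The tokenizer pairing is an assumption (state-spaces checkpoints are commonly used with the EleutherAI/gpt-neox-20b tokenizer); any tokenizer matching the checkpoint's vocabulary would work.

import torch
from transformers import AutoTokenizer
from mamba_lm import from_pretrained

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")   # assumed pairing
model = from_pretrained("state-spaces/mamba-130m")                     # MambaLM with HF weights
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# generate() steps through the prompt token by token, then samples num_tokens continuations
text = model.generate(tokenizer, "Mamba is a state-space model that", num_tokens=20, sample=True, top_k=40)
print(text)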
chess-gpt-eval/mamba_module.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import torch
|
4 |
+
from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
|
5 |
+
from contextlib import nullcontext
|
6 |
+
|
7 |
+
BASE_DIR = "mamba/"
|
8 |
+
|
9 |
+
class MambaPlayer:
|
10 |
+
def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
|
11 |
+
self.model_name = model_name
|
12 |
+
self.move_num_in_gamestate = move_num_in_gamestate
|
13 |
+
# -----------------------------------------------------------------------------
|
14 |
+
|
15 |
+
init_from = "resume" # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
|
16 |
+
out_dir = "out" # ignored if init_from is not 'resume'
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
#device = "cpu"
|
19 |
+
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
|
20 |
+
seed = 1337
|
21 |
+
compile = False # set to True if using PyTorch 2.0 and Mamba supports it
|
22 |
+
# -----------------------------------------------------------------------------
|
23 |
+
|
24 |
+
torch.manual_seed(seed)
|
25 |
+
torch.cuda.manual_seed(seed)
|
26 |
+
|
27 |
+
device_type = (
|
28 |
+
"cuda" if "cuda" in device else "cpu"
|
29 |
+
) # for later use in torch.autocast
|
30 |
+
ptdtype = {
|
31 |
+
"float32": torch.float32,
|
32 |
+
"bfloat16": torch.bfloat16,
|
33 |
+
"float16": torch.float16,
|
34 |
+
}[dtype]
|
35 |
+
ctx = (
|
36 |
+
nullcontext()
|
37 |
+
if device_type == "cpu"
|
38 |
+
else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
39 |
+
)
|
40 |
+
|
41 |
+
# Model initialization
|
42 |
+
if init_from == "resume":
|
43 |
+
#ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
|
44 |
+
ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
|
45 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
46 |
+
model_config = checkpoint["model_args"]
|
47 |
+
model = MambaLM(model_config)
|
48 |
+
model.load_state_dict(checkpoint['model'])
|
49 |
+
elif init_from.startswith('state-spaces'):
|
50 |
+
model = from_pretrained(init_from).to(device)
|
51 |
+
else:
|
52 |
+
raise ValueError("Invalid init_from value")
|
53 |
+
|
54 |
+
model.eval()
|
55 |
+
model.to(device)
|
56 |
+
|
57 |
+
if compile and hasattr(torch, 'compile'):
|
58 |
+
model = torch.compile(model)
|
59 |
+
|
60 |
+
# look for the meta pickle in case it is available in the dataset folder
|
61 |
+
meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
|
62 |
+
load_meta = os.path.exists(meta_path)
|
63 |
+
if move_num_in_gamestate and load_meta:
|
64 |
+
with open(meta_path, "rb") as f:
|
65 |
+
meta = pickle.load(f)
|
66 |
+
stoi, itos = meta["stoi"], meta["itos"]
|
67 |
+
vocab_size = meta['vocab_size']
|
68 |
+
encode = lambda s: [stoi[c] for c in s]
|
69 |
+
decode = lambda l: "".join([itos[i] for i in l])
|
70 |
+
else:
|
71 |
+
stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
|
72 |
+
itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
|
73 |
+
for s in stoi:
|
74 |
+
assert itos[stoi[s]] == s
|
75 |
+
vocab_size = len(stoi)
|
76 |
+
print(f"Vocab size {vocab_size}")
|
77 |
+
encode = lambda s: [stoi[c] for c in s.replace('-', '')]
|
78 |
+
decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
|
79 |
+
|
80 |
+
self.vocab_size = vocab_size
|
81 |
+
self.encode = encode
|
82 |
+
self.decode = decode
|
83 |
+
self.model = model
|
84 |
+
self.ctx = ctx
|
85 |
+
self.device = device
|
86 |
+
|
87 |
+
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
88 |
+
game_state = game_state.split("\n\n")[-1].strip()
|
89 |
+
#game_state = ";" + game_state
|
90 |
+
|
91 |
+
# Tokenize the game state
|
92 |
+
encoded_prompt = self.encode(game_state)
|
93 |
+
input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
|
94 |
+
|
95 |
+
self.model.eval() # Set the model to evaluation mode
|
96 |
+
with torch.no_grad():
|
97 |
+
have_non_space = False
|
98 |
+
for _ in range(max_new_tokens):
|
99 |
+
logits = self.model(input_ids)[0, -1, :] # Get logits for the last token
|
100 |
+
|
101 |
+
# Apply temperature scaling and optionally sample from top k tokens
|
102 |
+
logits = logits / temperature
|
103 |
+
if top_k > 0:
|
104 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
105 |
+
logits[indices_to_remove] = -float('Inf')
|
106 |
+
|
107 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
108 |
+
next_token_id = torch.multinomial(probs, num_samples=1)
|
109 |
+
if have_non_space and (next_token_id == 0 or next_token_id==4):
|
110 |
+
break
|
111 |
+
else:
|
112 |
+
have_non_space = True
|
113 |
+
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
114 |
+
|
115 |
+
model_response = self.decode(input_ids[0].tolist())
|
116 |
+
model_response = model_response[len(game_state):].split(";")[0]
|
117 |
+
return model_response
|
118 |
+
|
119 |
+
#def encode(self, text: str):
|
120 |
+
# Implement the appropriate tokenization for MambaLM
|
121 |
+
# This could be a simple mapping or a more complex tokenizer
|
122 |
+
# return [stoi[char] for char in text] # Example
|
123 |
+
|
124 |
+
#def decode(self, token_ids: list):
|
125 |
+
# Implement the appropriate decoding for MambaLM
|
126 |
+
# return ''.join([itos[id] for id in token_ids]) # Example
|
127 |
+
|
128 |
+
def get_move_from_response(self, response: str) -> str:
|
129 |
+
if not response:
|
130 |
+
return None
|
131 |
+
# Parse the response to get only the first move
|
132 |
+
moves = response.split()
|
133 |
+
first_move = moves[0]
|
134 |
+
first_move = first_move.lstrip('.') # A patch for a weird phase during training ... doesn't seem to be an issue anymore, but don't see the harm.
|
135 |
+
|
136 |
+
return first_move
|
137 |
+
|
138 |
+
def get_move(self, board: str, game_state: str, temperature: float) -> str:
|
139 |
+
completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
|
140 |
+
return self.get_move_from_response(completion)
|
141 |
+
|
142 |
+
def get_config(self) -> dict:
|
143 |
+
return {"model": self.model_name}
|
144 |
+
|
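A short usage sketch for MambaPlayer, assuming a checkpoint name that resolves under the path used in __init__ and a plain PGN movetext prompt like the one main.py builds; the file name ckpt.pt and the prompt are illustrative, not taken from this upload.

import chess

player = MambaPlayer(model_name="ckpt.pt")        # hypothetical checkpoint file
board = chess.Board()
game_state = "1.e4 e5 2.Nf3 "                     # movetext so far (headers already stripped)

move_san = player.get_move(str(board), game_state, temperature=0.1)
print(move_san)                                   # e.g. "Nc6" if the model replies legally
if move_san:
    board.push_san(move_san)                      # raises if the move is not legal SAN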
chess-gpt-eval/nanogpt/__pycache__/model.cpython-310.pyc
ADDED
Binary file (12.5 kB).

chess-gpt-eval/nanogpt/__pycache__/nanogpt_module.cpython-310.pyc
ADDED
Binary file (5.19 kB).

chess-gpt-eval/nanogpt/__pycache__/xformer.cpython-310.pyc
ADDED
Binary file (12.5 kB).
chess-gpt-eval/nanogpt/configurator.py
ADDED
@@ -0,0 +1,47 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, etc.)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
chess-gpt-eval/nanogpt/model.py
ADDED
@@ -0,0 +1,330 @@
1 |
+
"""
|
2 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
3 |
+
References:
|
4 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
5 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
6 |
+
2) huggingface/transformers PyTorch implementation:
|
7 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
8 |
+
"""
|
9 |
+
|
10 |
+
import math
|
11 |
+
import inspect
|
12 |
+
from dataclasses import dataclass
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
20 |
+
|
21 |
+
def __init__(self, ndim, bias):
|
22 |
+
super().__init__()
|
23 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
24 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
25 |
+
|
26 |
+
def forward(self, input):
|
27 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
28 |
+
|
29 |
+
class CausalSelfAttention(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, config):
|
32 |
+
super().__init__()
|
33 |
+
assert config.n_embd % config.n_head == 0
|
34 |
+
# key, query, value projections for all heads, but in a batch
|
35 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
36 |
+
# output projection
|
37 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
38 |
+
# regularization
|
39 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
40 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
41 |
+
self.n_head = config.n_head
|
42 |
+
self.n_embd = config.n_embd
|
43 |
+
self.dropout = config.dropout
|
44 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
|
45 |
+
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
46 |
+
if not self.flash:
|
47 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
48 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
49 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
50 |
+
.view(1, 1, config.block_size, config.block_size))
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
54 |
+
|
55 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
56 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
57 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
58 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
59 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
60 |
+
|
61 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
62 |
+
if self.flash:
|
63 |
+
# efficient attention using Flash Attention CUDA kernels
|
64 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
|
65 |
+
else:
|
66 |
+
# manual implementation of attention
|
67 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
68 |
+
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
69 |
+
att = F.softmax(att, dim=-1)
|
70 |
+
att = self.attn_dropout(att)
|
71 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
72 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
73 |
+
|
74 |
+
# output projection
|
75 |
+
y = self.resid_dropout(self.c_proj(y))
|
76 |
+
return y
|
77 |
+
|
78 |
+
class MLP(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self, config):
|
81 |
+
super().__init__()
|
82 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
83 |
+
self.gelu = nn.GELU()
|
84 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
85 |
+
self.dropout = nn.Dropout(config.dropout)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
x = self.c_fc(x)
|
89 |
+
x = self.gelu(x)
|
90 |
+
x = self.c_proj(x)
|
91 |
+
x = self.dropout(x)
|
92 |
+
return x
|
93 |
+
|
94 |
+
class Block(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self, config):
|
97 |
+
super().__init__()
|
98 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
99 |
+
self.attn = CausalSelfAttention(config)
|
100 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
101 |
+
self.mlp = MLP(config)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
x = x + self.attn(self.ln_1(x))
|
105 |
+
x = x + self.mlp(self.ln_2(x))
|
106 |
+
return x
|
107 |
+
|
108 |
+
@dataclass
|
109 |
+
class GPTConfig:
|
110 |
+
block_size: int = 1024
|
111 |
+
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
|
112 |
+
n_layer: int = 12
|
113 |
+
n_head: int = 12
|
114 |
+
n_embd: int = 768
|
115 |
+
dropout: float = 0.0
|
116 |
+
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
117 |
+
|
118 |
+
class GPT(nn.Module):
|
119 |
+
|
120 |
+
def __init__(self, config):
|
121 |
+
super().__init__()
|
122 |
+
assert config.vocab_size is not None
|
123 |
+
assert config.block_size is not None
|
124 |
+
self.config = config
|
125 |
+
|
126 |
+
self.transformer = nn.ModuleDict(dict(
|
127 |
+
wte = nn.Embedding(config.vocab_size, config.n_embd),
|
128 |
+
wpe = nn.Embedding(config.block_size, config.n_embd),
|
129 |
+
drop = nn.Dropout(config.dropout),
|
130 |
+
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
131 |
+
ln_f = LayerNorm(config.n_embd, bias=config.bias),
|
132 |
+
))
|
133 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
134 |
+
# with weight tying when using torch.compile() some warnings get generated:
|
135 |
+
# "UserWarning: functional_call was passed multiple values for tied weights.
|
136 |
+
# This behavior is deprecated and will be an error in future versions"
|
137 |
+
# not 100% sure what this is, so far seems to be harmless. TODO investigate
|
138 |
+
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
139 |
+
|
140 |
+
# init all weights
|
141 |
+
self.apply(self._init_weights)
|
142 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
143 |
+
for pn, p in self.named_parameters():
|
144 |
+
if pn.endswith('c_proj.weight'):
|
145 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
146 |
+
|
147 |
+
# report number of parameters
|
148 |
+
#print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
|
149 |
+
|
150 |
+
def get_num_params(self, non_embedding=True):
|
151 |
+
"""
|
152 |
+
Return the number of parameters in the model.
|
153 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
154 |
+
The token embeddings would too, except due to the parameter sharing these
|
155 |
+
params are actually used as weights in the final layer, so we include them.
|
156 |
+
"""
|
157 |
+
n_params = sum(p.numel() for p in self.parameters())
|
158 |
+
if non_embedding:
|
159 |
+
n_params -= self.transformer.wpe.weight.numel()
|
160 |
+
return n_params
|
161 |
+
|
162 |
+
def _init_weights(self, module):
|
163 |
+
if isinstance(module, nn.Linear):
|
164 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
165 |
+
if module.bias is not None:
|
166 |
+
torch.nn.init.zeros_(module.bias)
|
167 |
+
elif isinstance(module, nn.Embedding):
|
168 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
169 |
+
|
170 |
+
def forward(self, idx, targets=None):
|
171 |
+
device = idx.device
|
172 |
+
b, t = idx.size()
|
173 |
+
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
174 |
+
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
|
175 |
+
|
176 |
+
# forward the GPT model itself
|
177 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
178 |
+
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
|
179 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
180 |
+
for block in self.transformer.h:
|
181 |
+
x = block(x)
|
182 |
+
x = self.transformer.ln_f(x)
|
183 |
+
|
184 |
+
if targets is not None:
|
185 |
+
# if we are given some desired targets also calculate the loss
|
186 |
+
logits = self.lm_head(x)
|
187 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
188 |
+
else:
|
189 |
+
# inference-time mini-optimization: only forward the lm_head on the very last position
|
190 |
+
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
191 |
+
loss = None
|
192 |
+
|
193 |
+
return logits, loss
|
194 |
+
|
195 |
+
def crop_block_size(self, block_size):
|
196 |
+
# model surgery to decrease the block size if necessary
|
197 |
+
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
|
198 |
+
# but want to use a smaller block size for some smaller, simpler model
|
199 |
+
assert block_size <= self.config.block_size
|
200 |
+
self.config.block_size = block_size
|
201 |
+
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
|
202 |
+
for block in self.transformer.h:
|
203 |
+
if hasattr(block.attn, 'bias'):
|
204 |
+
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
|
205 |
+
|
206 |
+
@classmethod
|
207 |
+
def from_pretrained(cls, model_type, override_args=None):
|
208 |
+
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
|
209 |
+
override_args = override_args or {} # default to empty dict
|
210 |
+
# only dropout can be overridden see more notes below
|
211 |
+
assert all(k == 'dropout' for k in override_args)
|
212 |
+
from transformers import GPT2LMHeadModel
|
213 |
+
print("loading weights from pretrained gpt: %s" % model_type)
|
214 |
+
|
215 |
+
# n_layer, n_head and n_embd are determined from model_type
|
216 |
+
config_args = {
|
217 |
+
'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
|
218 |
+
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
|
219 |
+
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
|
220 |
+
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
|
221 |
+
}[model_type]
|
222 |
+
print("forcing vocab_size=50257, block_size=1024, bias=True")
|
223 |
+
config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
|
224 |
+
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
|
225 |
+
config_args['bias'] = True # always True for GPT model checkpoints
|
226 |
+
# we can override the dropout rate, if desired
|
227 |
+
if 'dropout' in override_args:
|
228 |
+
print(f"overriding dropout rate to {override_args['dropout']}")
|
229 |
+
config_args['dropout'] = override_args['dropout']
|
230 |
+
# create a from-scratch initialized minGPT model
|
231 |
+
config = GPTConfig(**config_args)
|
232 |
+
model = GPT(config)
|
233 |
+
sd = model.state_dict()
|
234 |
+
sd_keys = sd.keys()
|
235 |
+
sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
|
236 |
+
|
237 |
+
# init a huggingface/transformers model
|
238 |
+
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
|
239 |
+
sd_hf = model_hf.state_dict()
|
240 |
+
|
241 |
+
# copy while ensuring all of the parameters are aligned and match in names and shapes
|
242 |
+
sd_keys_hf = sd_hf.keys()
|
243 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
|
244 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
|
245 |
+
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
246 |
+
# basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
|
247 |
+
# this means that we have to transpose these weights when we import them
|
248 |
+
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
|
249 |
+
for k in sd_keys_hf:
|
250 |
+
if any(k.endswith(w) for w in transposed):
|
251 |
+
# special treatment for the Conv1D weights we need to transpose
|
252 |
+
assert sd_hf[k].shape[::-1] == sd[k].shape
|
253 |
+
with torch.no_grad():
|
254 |
+
sd[k].copy_(sd_hf[k].t())
|
255 |
+
else:
|
256 |
+
# vanilla copy over the other parameters
|
257 |
+
assert sd_hf[k].shape == sd[k].shape
|
258 |
+
with torch.no_grad():
|
259 |
+
sd[k].copy_(sd_hf[k])
|
260 |
+
|
261 |
+
return model
|
262 |
+
|
263 |
+
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
264 |
+
# start with all of the candidate parameters
|
265 |
+
param_dict = {pn: p for pn, p in self.named_parameters()}
|
266 |
+
# filter out those that do not require grad
|
267 |
+
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
|
268 |
+
# create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
|
269 |
+
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
270 |
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
271 |
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
272 |
+
optim_groups = [
|
273 |
+
{'params': decay_params, 'weight_decay': weight_decay},
|
274 |
+
{'params': nodecay_params, 'weight_decay': 0.0}
|
275 |
+
]
|
276 |
+
num_decay_params = sum(p.numel() for p in decay_params)
|
277 |
+
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
278 |
+
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
|
279 |
+
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
280 |
+
# Create AdamW optimizer and use the fused version if it is available
|
281 |
+
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
282 |
+
use_fused = fused_available and device_type == 'cuda'
|
283 |
+
extra_args = dict(fused=True) if use_fused else dict()
|
284 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
285 |
+
print(f"using fused AdamW: {use_fused}")
|
286 |
+
|
287 |
+
return optimizer
|
288 |
+
|
289 |
+
def estimate_mfu(self, fwdbwd_per_iter, dt):
|
290 |
+
""" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
|
291 |
+
# first estimate the number of flops we do per iteration.
|
292 |
+
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
|
293 |
+
N = self.get_num_params()
|
294 |
+
cfg = self.config
|
295 |
+
L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
|
296 |
+
flops_per_token = 6*N + 12*L*H*Q*T
|
297 |
+
flops_per_fwdbwd = flops_per_token * T
|
298 |
+
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
299 |
+
# express our flops throughput as ratio of A100 bfloat16 peak flops
|
300 |
+
flops_achieved = flops_per_iter * (1.0/dt) # per second
|
301 |
+
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
|
302 |
+
mfu = flops_achieved / flops_promised
|
303 |
+
return mfu
|
304 |
+
|
305 |
+
@torch.no_grad()
|
306 |
+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
307 |
+
"""
|
308 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
|
309 |
+
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
310 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
311 |
+
"""
|
312 |
+
for _ in range(max_new_tokens):
|
313 |
+
# if the sequence context is growing too long we must crop it at block_size
|
314 |
+
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
315 |
+
# forward the model to get the logits for the index in the sequence
|
316 |
+
logits, _ = self(idx_cond)
|
317 |
+
# pluck the logits at the final step and scale by desired temperature
|
318 |
+
logits = logits[:, -1, :] / temperature
|
319 |
+
# optionally crop the logits to only the top k options
|
320 |
+
if top_k is not None:
|
321 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
322 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
323 |
+
# apply softmax to convert logits to (normalized) probabilities
|
324 |
+
probs = F.softmax(logits, dim=-1)
|
325 |
+
# sample from the distribution
|
326 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
327 |
+
# append sampled index to the running sequence and continue
|
328 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
329 |
+
|
330 |
+
return idx
|
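A small sketch showing how GPTConfig and GPT above fit together for sampling; the tiny hyperparameters are arbitrary illustrative values, not the ones used in this project.

import torch
from nanogpt.model import GPTConfig, GPT

config = GPTConfig(block_size=256, vocab_size=32, n_layer=4, n_head=4, n_embd=128,
                   dropout=0.0, bias=False)
model = GPT(config).eval()

idx = torch.zeros((1, 1), dtype=torch.long)       # start from a single token id 0
out = model.generate(idx, max_new_tokens=8, temperature=1.0, top_k=5)
print(out.shape)                                  # torch.Size([1, 9]): prompt + 8 sampled tokens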
chess-gpt-eval/nanogpt/nanogpt_module.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
"""
|
2 |
+
Sample from a trained model
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
from contextlib import nullcontext
|
7 |
+
import torch
|
8 |
+
import tiktoken
|
9 |
+
from nanogpt.model import GPTConfig, GPT
|
10 |
+
|
11 |
+
BASE_DIR = "nanogpt/"
|
12 |
+
|
13 |
+
|
14 |
+
class NanoGptPlayer:
|
15 |
+
def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
|
16 |
+
self.model_name = model_name
|
17 |
+
# -----------------------------------------------------------------------------
|
18 |
+
|
19 |
+
init_from = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
|
20 |
+
out_dir = "out" # ignored if init_from is not 'resume'
|
21 |
+
input_dir = "addition"
|
22 |
+
test_name = "test.txt"
|
23 |
+
start = "12+44=" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
24 |
+
num_samples = 1 # number of samples to draw
|
25 |
+
max_new_tokens = 6 # number of tokens generated in each sample
|
26 |
+
temperature = 0.01 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
|
27 |
+
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
28 |
+
seed = 1337
|
29 |
+
device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
30 |
+
#device = "cpu"
|
31 |
+
dtype = "float16" # 'float32' or 'bfloat16' or 'float16'
|
32 |
+
compile = False # use PyTorch 2.0 to compile the model to be faster
|
33 |
+
exec(
|
34 |
+
open(f"{BASE_DIR}configurator.py").read()
|
35 |
+
) # overrides from command line or config file
|
36 |
+
# -----------------------------------------------------------------------------
|
37 |
+
|
38 |
+
torch.manual_seed(seed)
|
39 |
+
torch.cuda.manual_seed(seed)
|
40 |
+
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
41 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
42 |
+
device_type = (
|
43 |
+
"cuda" if "cuda" in device else "cpu"
|
44 |
+
) # for later use in torch.autocast
|
45 |
+
ptdtype = {
|
46 |
+
"float32": torch.float32,
|
47 |
+
"bfloat16": torch.bfloat16,
|
48 |
+
"float16": torch.float16,
|
49 |
+
}[dtype]
|
50 |
+
ctx = (
|
51 |
+
nullcontext()
|
52 |
+
if device_type == "cpu"
|
53 |
+
else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
54 |
+
)
|
55 |
+
|
56 |
+
# model
|
57 |
+
if init_from == "resume":
|
58 |
+
# init from a model saved in a specific directory
|
59 |
+
#ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
|
60 |
+
ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
|
61 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
62 |
+
#gptconf = GPTConfig(**checkpoint["model_args"])
|
63 |
+
#model = GPT(gptconf)
|
64 |
+
model = GPT(checkpoint["model_args"])
|
65 |
+
state_dict = checkpoint["model"]
|
66 |
+
unwanted_prefix = "_orig_mod."
|
67 |
+
for k, v in list(state_dict.items()):
|
68 |
+
if k.startswith(unwanted_prefix):
|
69 |
+
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
|
70 |
+
model.load_state_dict(state_dict)
|
71 |
+
elif init_from.startswith("gpt2"):
|
72 |
+
# init from a given GPT-2 model
|
73 |
+
model = GPT.from_pretrained(init_from, dict(dropout=0.0))
|
74 |
+
|
75 |
+
model.eval()
|
76 |
+
model.to(device)
|
77 |
+
if compile:
|
78 |
+
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
79 |
+
|
80 |
+
# look for the meta pickle in case it is available in the dataset folder
|
81 |
+
meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
|
82 |
+
load_meta = os.path.exists(meta_path)
|
83 |
+
if move_num_in_gamestate and load_meta:
|
84 |
+
with open(meta_path, "rb") as f:
|
85 |
+
meta = pickle.load(f)
|
86 |
+
stoi, itos = meta["stoi"], meta["itos"]
|
87 |
+
vocab_size = meta['vocab_size']
|
88 |
+
encode = lambda s: [stoi[c] for c in s]
|
89 |
+
decode = lambda l: "".join([itos[i] for i in l])
|
90 |
+
else:
|
91 |
+
stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
|
92 |
+
itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
|
93 |
+
for s in stoi:
|
94 |
+
assert itos[stoi[s]] == s
|
95 |
+
vocab_size = len(stoi)
|
96 |
+
print(f"Vocab size {vocab_size}")
|
97 |
+
encode = lambda s: [stoi[c] for c in s.replace('-', '')]
|
98 |
+
decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
|
99 |
+
|
100 |
+
self.encode = encode
|
101 |
+
self.decode = decode
|
102 |
+
self.model = model
|
103 |
+
self.ctx = ctx
|
104 |
+
self.device = device
|
105 |
+
|
106 |
+
def get_nanogpt_response(self, game_state: str, temperature: float) -> str:
|
107 |
+
num_samples = 1 # number of samples to draw
|
108 |
+
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
109 |
+
max_new_tokens = 8
|
110 |
+
|
111 |
+
# Remove ["stockfish elo xxx"]\n["stockfish elo xxx"]\n\n from game_state
|
112 |
+
# nanogpt was trained only on pgn transcripts
|
113 |
+
game_state = game_state.split("\n\n")[-1].strip()
|
114 |
+
|
115 |
+
# print("game_state", game_state)
|
116 |
+
|
117 |
+
#game_state = ";" + game_state
|
118 |
+
|
119 |
+
start_ids = self.encode(game_state)
|
120 |
+
|
121 |
+
x = torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]
|
122 |
+
with torch.no_grad():
|
123 |
+
with self.ctx:
|
124 |
+
for k in range(num_samples):
|
125 |
+
y = self.model.generate(
|
126 |
+
x, max_new_tokens, temperature=temperature, top_k=top_k
|
127 |
+
)
|
128 |
+
|
129 |
+
model_response = self.decode(y[0].tolist())
|
130 |
+
|
131 |
+
# print("model_response", model_response)
|
132 |
+
# model_response includes the input string
|
133 |
+
model_response = model_response[len(game_state):].split(";")[0]
|
134 |
+
return model_response
|
135 |
+
|
136 |
+
def get_move_from_response(self, response: str) -> str:
|
137 |
+
# Parse the response to get only the first move
|
138 |
+
moves = response.split()
|
139 |
+
first_move = moves[0]
|
140 |
+
|
141 |
+
return first_move
|
142 |
+
|
143 |
+
def get_move(self, board: str, game_state: str, temperature: float) -> str:
|
144 |
+
completion = self.get_nanogpt_response(game_state, temperature)
|
145 |
+
return self.get_move_from_response(completion)
|
146 |
+
|
147 |
+
def get_config(self) -> dict:
|
148 |
+
return {"model": self.model_name}
|
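A usage sketch mirroring the MambaPlayer example earlier; the checkpoint name is an assumption and must match whatever path NanoGptPlayer resolves for init_from='resume'.

player = NanoGptPlayer(model_name="ckpt.pt")      # hypothetical checkpoint file

game_state = "1.e4 e5 2.Nf3 "                     # PGN movetext prompt (illustrative)
move = player.get_move(board="", game_state=game_state, temperature=0.01)
print(move)                                       # first move of the model's continuation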
chess-gpt-eval/nanogpt/out/meta.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1121191e401988851de5744fe27fe463c3a086fc8c9a5538ef7fc12162bfb09
size 373
chess-gpt-eval/nanogpt/out/view_ckpt.ipynb
ADDED
@@ -0,0 +1,61 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"\n",
|
11 |
+
"def load_checkpoint(filepath: str) -> dict:\n",
|
12 |
+
" \"\"\"\n",
|
13 |
+
" Load a checkpoint file.\n",
|
14 |
+
"\n",
|
15 |
+
" Args:\n",
|
16 |
+
" filepath (str): Path to the .ckpt file.\n",
|
17 |
+
"\n",
|
18 |
+
" Returns:\n",
|
19 |
+
" dict: Contents of the checkpoint file.\n",
|
20 |
+
" \"\"\"\n",
|
21 |
+
" checkpoint = torch.load(filepath, map_location=torch.device('cpu'))\n",
|
22 |
+
" return checkpoint\n",
|
23 |
+
"\n",
|
24 |
+
"checkpoint_path = 'ckpt.pt'\n",
|
25 |
+
"checkpoint_data = load_checkpoint(checkpoint_path)\n",
|
26 |
+
"\n",
|
27 |
+
"# Print the keys to understand what's inside\n",
|
28 |
+
"print(checkpoint_data.keys())\n",
|
29 |
+
"\n",
|
30 |
+
"# If you want to view specific information, access it using the keys\n",
|
31 |
+
"# For example, to view the model's state_dict\n",
|
32 |
+
"model_state = checkpoint_data.get('state_dict', None)\n",
|
33 |
+
"if model_state:\n",
|
34 |
+
" print(\"Model's state dict:\", model_state)\n",
|
35 |
+
"\n",
|
36 |
+
"# To view training information like current learning rate, iterations, etc.\n",
|
37 |
+
"training_info = checkpoint_data.get('training_info', None)\n",
|
38 |
+
"if training_info:\n",
|
39 |
+
" print(\"Training Info:\", training_info)\n",
|
40 |
+
"\n",
|
41 |
+
"# To view config, if it's stored in the checkpoint\n",
|
42 |
+
"config = checkpoint_data.get('config', None)\n",
|
43 |
+
"if config:\n",
|
44 |
+
" print(\"Configurations:\", config)\n"
|
45 |
+
]
|
46 |
+
}
|
47 |
+
],
|
48 |
+
"metadata": {
|
49 |
+
"kernelspec": {
|
50 |
+
"display_name": "openai",
|
51 |
+
"language": "python",
|
52 |
+
"name": "python3"
|
53 |
+
},
|
54 |
+
"language_info": {
|
55 |
+
"name": "python",
|
56 |
+
"version": "3.10.13"
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"nbformat": 4,
|
60 |
+
"nbformat_minor": 2
|
61 |
+
}
|
chess-gpt-eval/openings.csv
ADDED
The diff for this file is too large to render.
chess-gpt-eval/pscan.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
An implementation of the parallel scan operation in PyTorch (Blelloch version).
|
9 |
+
Please see docs/pscan.ipynb for a detailed explanation of what happens here.
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
def npo2(len):
|
14 |
+
"""
|
15 |
+
Returns the next power of 2 above len
|
16 |
+
"""
|
17 |
+
|
18 |
+
return 2 ** math.ceil(math.log2(len))
|
19 |
+
|
20 |
+
def pad_npo2(X):
|
21 |
+
"""
|
22 |
+
Pads input length dim to the next power of 2
|
23 |
+
|
24 |
+
Args:
|
25 |
+
X : (B, L, D, N)
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Y : (B, npo2(L), D, N)
|
29 |
+
"""
|
30 |
+
|
31 |
+
len_npo2 = npo2(X.size(1))
|
32 |
+
pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
|
33 |
+
return F.pad(X, pad_tuple, "constant", 0)
|
34 |
+
|
35 |
+
class PScan(torch.autograd.Function):
|
36 |
+
@staticmethod
|
37 |
+
def pscan(A, X):
|
38 |
+
# A : (B, D, L, N)
|
39 |
+
# X : (B, D, L, N)
|
40 |
+
|
41 |
+
# modifies X in place by doing a parallel scan.
|
42 |
+
# more formally, X will be populated by these values :
|
43 |
+
# H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
|
44 |
+
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
|
45 |
+
|
46 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
47 |
+
|
48 |
+
B, D, L, _ = A.size()
|
49 |
+
num_steps = int(math.log2(L))
|
50 |
+
|
51 |
+
# up sweep (last 2 steps unfolded)
|
52 |
+
Aa = A
|
53 |
+
Xa = X
|
54 |
+
for _ in range(num_steps-2):
|
55 |
+
T = Xa.size(2)
|
56 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
57 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
58 |
+
|
59 |
+
Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
|
60 |
+
Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
|
61 |
+
|
62 |
+
Aa = Aa[:, :, :, 1]
|
63 |
+
Xa = Xa[:, :, :, 1]
|
64 |
+
|
65 |
+
# we have only 4, 2 or 1 nodes left
|
66 |
+
if Xa.size(2) == 4:
|
67 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
68 |
+
Aa[:, :, 1].mul_(Aa[:, :, 0])
|
69 |
+
|
70 |
+
Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
|
71 |
+
elif Xa.size(2) == 2:
|
72 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
73 |
+
return
|
74 |
+
else:
|
75 |
+
return
|
76 |
+
|
77 |
+
# down sweep (first 2 steps unfolded)
|
78 |
+
Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
79 |
+
Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
80 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
|
81 |
+
Aa[:, :, 2].mul_(Aa[:, :, 1])
|
82 |
+
|
83 |
+
for k in range(num_steps-3, -1, -1):
|
84 |
+
Aa = A[:, :, 2**k-1:L:2**k]
|
85 |
+
Xa = X[:, :, 2**k-1:L:2**k]
|
86 |
+
|
87 |
+
T = Xa.size(2)
|
88 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
89 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
90 |
+
|
91 |
+
Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
|
92 |
+
Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def pscan_rev(A, X):
|
96 |
+
# A : (B, D, L, N)
|
97 |
+
# X : (B, D, L, N)
|
98 |
+
|
99 |
+
# the same function as above, but in reverse
|
100 |
+
# (if you flip the input, call pscan, then flip the output, you get what this function outputs)
|
101 |
+
# it is used in the backward pass
|
102 |
+
|
103 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
104 |
+
|
105 |
+
B, D, L, _ = A.size()
|
106 |
+
num_steps = int(math.log2(L))
|
107 |
+
|
108 |
+
# up sweep (last 2 steps unfolded)
|
109 |
+
Aa = A
|
110 |
+
Xa = X
|
111 |
+
for _ in range(num_steps-2):
|
112 |
+
T = Xa.size(2)
|
113 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
114 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
115 |
+
|
116 |
+
Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
|
117 |
+
Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
|
118 |
+
|
119 |
+
Aa = Aa[:, :, :, 0]
|
120 |
+
Xa = Xa[:, :, :, 0]
|
121 |
+
|
122 |
+
# we have only 4, 2 or 1 nodes left
|
123 |
+
if Xa.size(2) == 4:
|
124 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
|
125 |
+
Aa[:, :, 2].mul_(Aa[:, :, 3])
|
126 |
+
|
127 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
|
128 |
+
elif Xa.size(2) == 2:
|
129 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
|
130 |
+
return
|
131 |
+
else:
|
132 |
+
return
|
133 |
+
|
134 |
+
# down sweep (first 2 steps unfolded)
|
135 |
+
Aa = A[:, :, 0:L:2**(num_steps-2)]
|
136 |
+
Xa = X[:, :, 0:L:2**(num_steps-2)]
|
137 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
|
138 |
+
Aa[:, :, 1].mul_(Aa[:, :, 2])
|
139 |
+
|
140 |
+
for k in range(num_steps-3, -1, -1):
|
141 |
+
Aa = A[:, :, 0:L:2**k]
|
142 |
+
Xa = X[:, :, 0:L:2**k]
|
143 |
+
|
144 |
+
T = Xa.size(2)
|
145 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
146 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
147 |
+
|
148 |
+
Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
|
149 |
+
Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def forward(ctx, A_in, X_in):
|
153 |
+
"""
|
154 |
+
Applies the parallel scan operation, as defined above. Returns a new tensor.
|
155 |
+
If you can, prefer sequence lengths that are powers of two.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
A_in : (B, L, D, N)
|
159 |
+
X_in : (B, L, D, N)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
H : (B, L, D, N)
|
163 |
+
"""
|
164 |
+
|
165 |
+
L = X_in.size(1)
|
166 |
+
|
167 |
+
# cloning is required because of the in-place ops
|
168 |
+
if L == npo2(L):
|
169 |
+
A = A_in.clone()
|
170 |
+
X = X_in.clone()
|
171 |
+
else:
|
172 |
+
# pad tensors (and clone btw)
|
173 |
+
A = pad_npo2(A_in) # (B, npo2(L), D, N)
|
174 |
+
X = pad_npo2(X_in) # (B, npo2(L), D, N)
|
175 |
+
|
176 |
+
# prepare tensors
|
177 |
+
A = A.transpose(2, 1) # (B, D, npo2(L), N)
|
178 |
+
X = X.transpose(2, 1) # (B, D, npo2(L), N)
|
179 |
+
|
180 |
+
# parallel scan (modifies X in-place)
|
181 |
+
PScan.pscan(A, X)
|
182 |
+
|
183 |
+
ctx.save_for_backward(A_in, X)
|
184 |
+
|
185 |
+
# slice [:, :L] (cut if there was padding)
|
186 |
+
return X.transpose(2, 1)[:, :L]
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def backward(ctx, grad_output_in):
|
190 |
+
"""
|
191 |
+
Flows the gradient from the output to the input. Returns two new tensors.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
ctx : A_in : (B, L, D, N), X : (B, D, L, N)
|
195 |
+
grad_output_in : (B, L, D, N)
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
gradA : (B, L, D, N), gradX : (B, L, D, N)
|
199 |
+
"""
|
200 |
+
|
201 |
+
A_in, X = ctx.saved_tensors
|
202 |
+
|
203 |
+
L = grad_output_in.size(1)
|
204 |
+
|
205 |
+
# cloning is required because of the in-place ops
|
206 |
+
if L == npo2(L):
|
207 |
+
grad_output = grad_output_in.clone()
|
208 |
+
# the next padding will clone A_in
|
209 |
+
else:
|
210 |
+
grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
|
211 |
+
A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
|
212 |
+
|
213 |
+
# prepare tensors
|
214 |
+
grad_output = grad_output.transpose(2, 1)
|
215 |
+
A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
|
216 |
+
A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
|
217 |
+
|
218 |
+
# reverse parallel scan (modifies grad_output in-place)
|
219 |
+
PScan.pscan_rev(A, grad_output)
|
220 |
+
|
221 |
+
Q = torch.zeros_like(X)
|
222 |
+
Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
|
223 |
+
|
224 |
+
return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
|
225 |
+
|
226 |
+
pscan = PScan.apply
|
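A sanity-check sketch (not part of the upload): the parallel scan should match the sequential recurrence H[t] = A[t] * H[t-1] + X[t] with a zero initial hidden state, as described in the pscan() docstring. The shapes are arbitrary test values.

import torch
from pscan import pscan

B, L, D, N = 2, 16, 3, 4                          # L is a power of two, as the code prefers
A = torch.rand(B, L, D, N, dtype=torch.float64)
X = torch.rand(B, L, D, N, dtype=torch.float64)

H_parallel = pscan(A, X)                          # (B, L, D, N)

# naive sequential reference for the same recurrence
H_seq = torch.zeros_like(X)
h = torch.zeros(B, D, N, dtype=torch.float64)
for t in range(L):
    h = A[:, t] * h + X[:, t]
    H_seq[:, t] = h

print(torch.allclose(H_parallel, H_seq))          # True up to floating-point error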
chess-gpt-eval/requirements.txt
ADDED
@@ -0,0 +1,6 @@
openai==0.28.0
tiktoken==0.4.0
tenacity==8.2.3
python-chess==1.999
matplotlib==3.8.0
pandas==2.1.1
chess-gpt-eval/xformer.py
ADDED
@@ -0,0 +1,330 @@
1 |
+
"""
|
2 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
3 |
+
References:
|
4 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
5 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
6 |
+
2) huggingface/transformers PyTorch implementation:
|
7 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
8 |
+
"""
|
9 |
+
|
10 |
+
import math
|
11 |
+
import inspect
|
12 |
+
from dataclasses import dataclass
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
20 |
+
|
21 |
+
def __init__(self, ndim, bias):
|
22 |
+
super().__init__()
|
23 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
24 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
25 |
+
|
26 |
+
def forward(self, input):
|
27 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
28 |
+
|
29 |
+
class CausalSelfAttention(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, config):
|
32 |
+
super().__init__()
|
33 |
+
assert config.n_embd % config.n_head == 0
|
34 |
+
# key, query, value projections for all heads, but in a batch
|
35 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
36 |
+
# output projection
|
37 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
38 |
+
# regularization
|
39 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
40 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
41 |
+
self.n_head = config.n_head
|
42 |
+
self.n_embd = config.n_embd
|
43 |
+
self.dropout = config.dropout
|
44 |
+
# flash attention makes the GPU go brrrrr, but support is only in PyTorch >= 2.0
|
45 |
+
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
46 |
+
if not self.flash:
|
47 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
48 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
49 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
50 |
+
.view(1, 1, config.block_size, config.block_size))
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
54 |
+
|
55 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
56 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
57 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
58 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
59 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
60 |
+
|
61 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
62 |
+
if self.flash:
|
63 |
+
# efficient attention using Flash Attention CUDA kernels
|
64 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
|
65 |
+
else:
|
66 |
+
# manual implementation of attention
|
67 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
68 |
+
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
69 |
+
att = F.softmax(att, dim=-1)
|
70 |
+
att = self.attn_dropout(att)
|
71 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
72 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
73 |
+
|
74 |
+
# output projection
|
75 |
+
y = self.resid_dropout(self.c_proj(y))
|
76 |
+
return y
|
77 |
+
|
78 |
+
class MLP(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self, config):
|
81 |
+
super().__init__()
|
82 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
83 |
+
self.gelu = nn.GELU()
|
84 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
85 |
+
self.dropout = nn.Dropout(config.dropout)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
x = self.c_fc(x)
|
89 |
+
x = self.gelu(x)
|
90 |
+
x = self.c_proj(x)
|
91 |
+
x = self.dropout(x)
|
92 |
+
return x
|
93 |
+
|
94 |
+
class Block(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self, config):
|
97 |
+
super().__init__()
|
98 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
99 |
+
self.attn = CausalSelfAttention(config)
|
100 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
101 |
+
self.mlp = MLP(config)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
x = x + self.attn(self.ln_1(x))
|
105 |
+
x = x + self.mlp(self.ln_2(x))
|
106 |
+
return x
|
107 |
+
|
108 |
+
@dataclass
|
109 |
+
class GPTConfig:
|
110 |
+
block_size: int = 1024
|
111 |
+
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
|
112 |
+
n_layer: int = 12
|
113 |
+
n_head: int = 12
|
114 |
+
n_embd: int = 768
|
115 |
+
dropout: float = 0.0
|
116 |
+
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
117 |
+
|
118 |
+
class GPT(nn.Module):
|
119 |
+
|
120 |
+
def __init__(self, config):
|
121 |
+
super().__init__()
|
122 |
+
assert config.vocab_size is not None
|
123 |
+
assert config.block_size is not None
|
124 |
+
self.config = config
|
125 |
+
|
126 |
+
self.transformer = nn.ModuleDict(dict(
|
127 |
+
wte = nn.Embedding(config.vocab_size, config.n_embd),
|
128 |
+
wpe = nn.Embedding(config.block_size, config.n_embd),
|
129 |
+
drop = nn.Dropout(config.dropout),
|
130 |
+
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
131 |
+
ln_f = LayerNorm(config.n_embd, bias=config.bias),
|
132 |
+
))
|
133 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
134 |
+
# with weight tying when using torch.compile() some warnings get generated:
|
135 |
+
# "UserWarning: functional_call was passed multiple values for tied weights.
|
136 |
+
# This behavior is deprecated and will be an error in future versions"
|
137 |
+
# not 100% sure what this is, so far seems to be harmless. TODO investigate
|
138 |
+
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
139 |
+
|
140 |
+
# init all weights
|
141 |
+
self.apply(self._init_weights)
|
142 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
143 |
+
for pn, p in self.named_parameters():
|
144 |
+
if pn.endswith('c_proj.weight'):
|
145 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
146 |
+
|
147 |
+
# report number of parameters
|
148 |
+
#print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
|
149 |
+
|
150 |
+
def get_num_params(self, non_embedding=True):
|
151 |
+
"""
|
152 |
+
Return the number of parameters in the model.
|
153 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
154 |
+
The token embeddings would too, except due to the parameter sharing these
|
155 |
+
params are actually used as weights in the final layer, so we include them.
|
156 |
+
"""
|
157 |
+
n_params = sum(p.numel() for p in self.parameters())
|
158 |
+
if non_embedding:
|
159 |
+
n_params -= self.transformer.wpe.weight.numel()
|
160 |
+
return n_params
|
161 |
+
|
162 |
+
def _init_weights(self, module):
|
163 |
+
if isinstance(module, nn.Linear):
|
164 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
165 |
+
if module.bias is not None:
|
166 |
+
torch.nn.init.zeros_(module.bias)
|
167 |
+
elif isinstance(module, nn.Embedding):
|
168 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
169 |
+
|
170 |
+
def forward(self, idx, targets=None):
|
171 |
+
device = idx.device
|
172 |
+
b, t = idx.size()
|
173 |
+
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
174 |
+
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
|
175 |
+
|
176 |
+
# forward the GPT model itself
|
177 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
178 |
+
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
|
179 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
180 |
+
for block in self.transformer.h:
|
181 |
+
x = block(x)
|
182 |
+
x = self.transformer.ln_f(x)
|
183 |
+
|
184 |
+
if targets is not None:
|
185 |
+
# if we are given some desired targets also calculate the loss
|
186 |
+
logits = self.lm_head(x)
|
187 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
188 |
+
else:
|
189 |
+
# inference-time mini-optimization: only forward the lm_head on the very last position
|
190 |
+
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
191 |
+
loss = None
|
192 |
+
|
193 |
+
return logits, loss
|
194 |
+
|
195 |
+
def crop_block_size(self, block_size):
|
196 |
+
# model surgery to decrease the block size if necessary
|
197 |
+
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
|
198 |
+
# but want to use a smaller block size for some smaller, simpler model
|
199 |
+
assert block_size <= self.config.block_size
|
200 |
+
self.config.block_size = block_size
|
201 |
+
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
|
202 |
+
for block in self.transformer.h:
|
203 |
+
if hasattr(block.attn, 'bias'):
|
204 |
+
block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
|
205 |
+
|
206 |
+
@classmethod
|
207 |
+
def from_pretrained(cls, model_type, override_args=None):
|
208 |
+
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
|
209 |
+
override_args = override_args or {} # default to empty dict
|
210 |
+
# only dropout can be overridden see more notes below
|
211 |
+
assert all(k == 'dropout' for k in override_args)
|
212 |
+
from transformers import GPT2LMHeadModel
|
213 |
+
print("loading weights from pretrained gpt: %s" % model_type)
|
214 |
+
|
215 |
+
# n_layer, n_head and n_embd are determined from model_type
|
216 |
+
config_args = {
|
217 |
+
'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
|
218 |
+
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
|
219 |
+
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
|
220 |
+
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
|
221 |
+
}[model_type]
|
222 |
+
print("forcing vocab_size=50257, block_size=1024, bias=True")
|
223 |
+
config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
|
224 |
+
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
|
225 |
+
config_args['bias'] = True # always True for GPT model checkpoints
|
226 |
+
# we can override the dropout rate, if desired
|
227 |
+
if 'dropout' in override_args:
|
228 |
+
print(f"overriding dropout rate to {override_args['dropout']}")
|
229 |
+
config_args['dropout'] = override_args['dropout']
|
230 |
+
# create a from-scratch initialized minGPT model
|
231 |
+
config = GPTConfig(**config_args)
|
232 |
+
model = GPT(config)
|
233 |
+
sd = model.state_dict()
|
234 |
+
sd_keys = sd.keys()
|
235 |
+
sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
|
236 |
+
|
237 |
+
# init a huggingface/transformers model
|
238 |
+
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
|
239 |
+
sd_hf = model_hf.state_dict()
|
240 |
+
|
241 |
+
# copy while ensuring all of the parameters are aligned and match in names and shapes
|
242 |
+
sd_keys_hf = sd_hf.keys()
|
243 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
|
244 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
|
245 |
+
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
246 |
+
# basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
|
247 |
+
# this means that we have to transpose these weights when we import them
|
248 |
+
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
|
249 |
+
for k in sd_keys_hf:
|
250 |
+
if any(k.endswith(w) for w in transposed):
|
251 |
+
# special treatment for the Conv1D weights we need to transpose
|
252 |
+
assert sd_hf[k].shape[::-1] == sd[k].shape
|
253 |
+
with torch.no_grad():
|
254 |
+
sd[k].copy_(sd_hf[k].t())
|
255 |
+
else:
|
256 |
+
# vanilla copy over the other parameters
|
257 |
+
assert sd_hf[k].shape == sd[k].shape
|
258 |
+
with torch.no_grad():
|
259 |
+
sd[k].copy_(sd_hf[k])
|
260 |
+
|
261 |
+
return model
|
262 |
+
|
263 |
+
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
264 |
+
# start with all of the candidate parameters
|
265 |
+
param_dict = {pn: p for pn, p in self.named_parameters()}
|
266 |
+
# filter out those that do not require grad
|
267 |
+
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
|
268 |
+
# create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
|
269 |
+
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
270 |
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
271 |
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
272 |
+
optim_groups = [
|
273 |
+
{'params': decay_params, 'weight_decay': weight_decay},
|
274 |
+
{'params': nodecay_params, 'weight_decay': 0.0}
|
275 |
+
]
|
276 |
+
num_decay_params = sum(p.numel() for p in decay_params)
|
277 |
+
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
278 |
+
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
|
279 |
+
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
280 |
+
# Create AdamW optimizer and use the fused version if it is available
|
281 |
+
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
282 |
+
use_fused = fused_available and device_type == 'cuda'
|
283 |
+
extra_args = dict(fused=True) if use_fused else dict()
|
284 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
285 |
+
print(f"using fused AdamW: {use_fused}")
|
286 |
+
|
287 |
+
return optimizer
|
288 |
+
|
289 |
+
def estimate_mfu(self, fwdbwd_per_iter, dt):
|
290 |
+
""" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
|
291 |
+
# first estimate the number of flops we do per iteration.
|
292 |
+
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
|
293 |
+
N = self.get_num_params()
|
294 |
+
cfg = self.config
|
295 |
+
L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
|
296 |
+
flops_per_token = 6*N + 12*L*H*Q*T
|
297 |
+
flops_per_fwdbwd = flops_per_token * T
|
298 |
+
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
299 |
+
# express our flops throughput as ratio of A100 bfloat16 peak flops
|
300 |
+
flops_achieved = flops_per_iter * (1.0/dt) # per second
|
301 |
+
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
|
302 |
+
mfu = flops_achieved / flops_promised
|
303 |
+
return mfu
|
304 |
+
|
305 |
+
@torch.no_grad()
|
306 |
+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
307 |
+
"""
|
308 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
|
309 |
+
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
310 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
311 |
+
"""
|
312 |
+
for _ in range(max_new_tokens):
|
313 |
+
# if the sequence context is growing too long we must crop it at block_size
|
314 |
+
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
315 |
+
# forward the model to get the logits for the index in the sequence
|
316 |
+
logits, _ = self(idx_cond)
|
317 |
+
# pluck the logits at the final step and scale by desired temperature
|
318 |
+
logits = logits[:, -1, :] / temperature
|
319 |
+
# optionally crop the logits to only the top k options
|
320 |
+
if top_k is not None:
|
321 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
322 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
323 |
+
# apply softmax to convert logits to (normalized) probabilities
|
324 |
+
probs = F.softmax(logits, dim=-1)
|
325 |
+
# sample from the distribution
|
326 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
327 |
+
# append sampled index to the running sequence and continue
|
328 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
329 |
+
|
330 |
+
return idx
|
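Note: a minimal usage sketch of the GPT class defined above; the config values are illustrative toy sizes, not the training configs used in this repo, and it assumes xformer.py is on the import path.

import torch
from xformer import GPT, GPTConfig

config = GPTConfig(block_size=64, vocab_size=128, n_layer=2, n_head=2,
                   n_embd=64, dropout=0.0, bias=False)
model = GPT(config)
model.eval()

idx = torch.zeros((1, 1), dtype=torch.long)          # dummy start token, id 0
out = model.generate(idx, max_new_tokens=8, temperature=1.0, top_k=5)
print(out.shape)                                      # torch.Size([1, 9])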
chess-mamba-vs-xformer/config/Mamba/11M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 256
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 200
|
25 |
+
eval_iters = 33
|
26 |
+
log_interval = 33
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 500 # not super necessary potentially
|
30 |
+
learning_rate = 1e-3
|
31 |
+
min_lr = 6.6667e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 400000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Mamba/11M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-mamba-v2'
|
55 |
+
wandb_run_name = 'Mamba-11M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 11M param
|
60 |
+
model_type = 'mamba'
|
61 |
+
n_layer = 20
|
62 |
+
d_model = 288
|
63 |
+
d_state = 16
|
64 |
+
dt_rank = 'auto' #ceil(d_model/16) ... 18 here
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
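Note: a worked example of the batch-size rescaling these configs perform. With batch_size=100, gradient_accumulation_steps=1 and base_batch_size=256, the learning rates are scaled by sqrt(100/256)=0.625 and the iteration counts by 256/100=2.56; the numbers below mirror this 11M config.

import math

base_batch_size = 256
effective_batch_size = 100    # batch_size * gradient_accumulation_steps

lr = 1e-3 * math.sqrt(effective_batch_size / base_batch_size)        # 6.25e-4
warmup = int(500 * (base_batch_size / effective_batch_size))         # 1280
max_iters = int(400000 * (base_batch_size / effective_batch_size))   # 1024000
print(lr, warmup, max_iters)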
chess-mamba-vs-xformer/config/Mamba/250M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 256
|
18 |
+
|
19 |
+
batch_size = 10
|
20 |
+
gradient_accumulation_steps = 10
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 300
|
25 |
+
eval_iters = 33
|
26 |
+
log_interval = 75
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 500 # not super necessary potentially
|
30 |
+
learning_rate = 2.0e-3 # tested 1.5e-3 from 112k-156k, before that 3.5e-3 #8e-3
|
31 |
+
min_lr = 1.3333e-4
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 400000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Mamba/250M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True
|
54 |
+
wandb_project = 'chess-mamba-v2'
|
55 |
+
wandb_run_name = 'Mamba-250M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 251M param
|
60 |
+
model_type = 'mamba'
|
61 |
+
n_layer = 96
|
62 |
+
d_model = 578
|
63 |
+
d_state = 56
|
64 |
+
dt_rank = 176
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Mamba/29M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 256
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 250
|
25 |
+
eval_iters = 33
|
26 |
+
log_interval = 50
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 500 # not super necessary potentially
|
30 |
+
learning_rate = 1.25e-3
|
31 |
+
min_lr = 8.3333e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 400000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Mamba/29M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-mamba-v2'
|
55 |
+
wandb_run_name = 'Mamba-29M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 29.3M param
|
60 |
+
model_type = 'mamba'
|
61 |
+
n_layer = 33
|
62 |
+
d_model = 360
|
63 |
+
d_state = 24
|
64 |
+
dt_rank = 36
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Mamba/50M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 256
|
18 |
+
|
19 |
+
batch_size = 50
|
20 |
+
gradient_accumulation_steps = 2
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 250
|
25 |
+
eval_iters = 33
|
26 |
+
log_interval = 50
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 500 # not super necessary potentially
|
30 |
+
learning_rate = 1.5e-3 # tested 1.5e-3 from 112k-156k, before that 3.5e-3 #8e-3
|
31 |
+
min_lr = 1.0e-4 # was planning 8.5e-5 w/ /6.75 anneal #... before 2e-4 # 4.75e-4
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 400000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Mamba/50M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True
|
54 |
+
wandb_project = 'chess-mamba-v2'
|
55 |
+
wandb_run_name = 'Mamba-50M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 50.4M param
|
60 |
+
model_type = 'mamba'
|
61 |
+
n_layer = 48
|
62 |
+
d_model = 384
|
63 |
+
d_state = 32
|
64 |
+
dt_rank = 56
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Mamba/6.6M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 256
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 200
|
25 |
+
eval_iters = 33
|
26 |
+
log_interval = 33
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 500 # not super necessary potentially
|
30 |
+
learning_rate = 8.16667e-4
|
31 |
+
min_lr = 5.4444e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 400000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Mamba/6.6M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-mamba-v2'
|
55 |
+
wandb_run_name = 'Mamba-6.6M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 6.6M param
|
60 |
+
model_type = 'mamba'
|
61 |
+
n_layer = 15
|
62 |
+
d_model = 256
|
63 |
+
d_state = 16
|
64 |
+
dt_rank = 'auto' #ceil(d_model/16) ... 16 here
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Xformer/11M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 100
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 600
|
25 |
+
eval_iters = 100
|
26 |
+
log_interval = 100
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 1280 # not super necessary potentially
|
30 |
+
learning_rate = 2e-4
|
31 |
+
min_lr = 1.33333e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 1024000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Xformer/11M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-xformer'
|
55 |
+
wandb_run_name = 'Xformer-11M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 11.2M param
|
60 |
+
model_type = 'xformer'
|
61 |
+
n_layer = 6
|
62 |
+
n_head = 6
|
63 |
+
n_embd = 384
|
64 |
+
dropout = 0.0
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Xformer/250M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 100
|
18 |
+
|
19 |
+
batch_size = 10
|
20 |
+
gradient_accumulation_steps = 10
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 900
|
25 |
+
eval_iters = 100
|
26 |
+
log_interval = 225
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 1280 # not super necessary potentially
|
30 |
+
learning_rate = 4e-4
|
31 |
+
min_lr = 2.6667e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 1024000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Xformer/250M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True
|
54 |
+
wandb_project = 'chess-xformer'
|
55 |
+
wandb_run_name = 'Xformer-250M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 251.2M param
|
60 |
+
model_type = 'xformer'
|
61 |
+
n_layer = 51
|
62 |
+
n_head = 16
|
63 |
+
n_embd = 640
|
64 |
+
dropout = 0.0
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Xformer/29M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 100
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 750
|
25 |
+
eval_iters = 100
|
26 |
+
log_interval = 150
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 1280 # not super necessary potentially
|
30 |
+
learning_rate = 2.5e-4
|
31 |
+
min_lr = 1.6667e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 1024000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Xformer/29M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-xformer'
|
55 |
+
wandb_run_name = 'Xformer-29M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 29.1M param
|
60 |
+
model_type = 'xformer'
|
61 |
+
n_layer = 9
|
62 |
+
n_head = 8
|
63 |
+
n_embd = 512
|
64 |
+
dropout = 0.0
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Xformer/50M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 100
|
18 |
+
|
19 |
+
batch_size = 50
|
20 |
+
gradient_accumulation_steps = 2
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 750
|
25 |
+
eval_iters = 100
|
26 |
+
log_interval = 150
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 1280 # not super necessary potentially
|
30 |
+
learning_rate = 3e-4 # Mamba is 9.375e-4 (adjusting for different base_batch_size)
|
31 |
+
min_lr = 2e-5 # Same ratio min/max as w/ Mamba. It's lower than 1/10 because doing slightly long anneal.
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 1024000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Xformer/50M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True
|
54 |
+
wandb_project = 'chess-xformer'
|
55 |
+
wandb_run_name = 'Xformer-50M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 50.8M param
|
60 |
+
model_type = 'xformer'
|
61 |
+
n_layer = 16
|
62 |
+
n_head = 8
|
63 |
+
n_embd = 512
|
64 |
+
dropout = 0.0
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'scratch'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/config/Xformer/6.6M.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
beta1 = 0.9
|
5 |
+
beta2 = 0.95
|
6 |
+
weight_decay = 4.5e-3
|
7 |
+
grad_clip = 0.5
|
8 |
+
auto_clip = True
|
9 |
+
auto_clip_max = 0.5
|
10 |
+
auto_clip_min = 3.333e-3
|
11 |
+
grad_clip_start_size = 100
|
12 |
+
grad_clip_max_size = 400
|
13 |
+
grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
|
14 |
+
max_seq_len = 1536
|
15 |
+
|
16 |
+
# The values below are based on this batch size; when the actual batch size is changed, they are rescaled automatically.
|
17 |
+
base_batch_size = 100
|
18 |
+
|
19 |
+
batch_size = 100
|
20 |
+
gradient_accumulation_steps = 1
|
21 |
+
effective_batch_size = batch_size * gradient_accumulation_steps
|
22 |
+
|
23 |
+
always_save_checkpoint = True
|
24 |
+
eval_interval = 600
|
25 |
+
eval_iters = 100
|
26 |
+
log_interval = 100
|
27 |
+
train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
|
28 |
+
|
29 |
+
warmup_iters = 1280 # not super necessary potentially
|
30 |
+
learning_rate = 1.633333e-4
|
31 |
+
min_lr = 1.08889e-5
|
32 |
+
# max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
|
33 |
+
max_iters = 1024000 #~=102M games
|
34 |
+
|
35 |
+
# # # # #
|
36 |
+
|
37 |
+
warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
|
38 |
+
learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
|
39 |
+
max_iters = int(max_iters * (base_batch_size / effective_batch_size))
|
40 |
+
min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
|
41 |
+
|
42 |
+
out_dir = 'out/Xformer/6.6M'
|
43 |
+
eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
|
44 |
+
eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
|
45 |
+
log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
|
46 |
+
|
47 |
+
print(f'warmup iters: {warmup_iters}')
|
48 |
+
print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
|
49 |
+
print(f'Eval iters: {eval_iters}')
|
50 |
+
print(f'Eval interval: {eval_interval}')
|
51 |
+
print(f'Log interval: {log_interval}')
|
52 |
+
|
53 |
+
wandb_log = True # override via command line if you like
|
54 |
+
wandb_project = 'chess-xformer'
|
55 |
+
wandb_run_name = 'Xformer-6.6M'
|
56 |
+
|
57 |
+
dataset = 'stable'
|
58 |
+
|
59 |
+
# 6.6M param
|
60 |
+
model_type = 'xformer'
|
61 |
+
n_layer = 5
|
62 |
+
n_head = 5
|
63 |
+
n_embd = 320
|
64 |
+
dropout = 0.0
|
65 |
+
move_num_in_gamestate = False
|
66 |
+
|
67 |
+
init_from = 'resume'
|
68 |
+
|
69 |
+
device = 'cuda' # set to 'cpu' to run on CPU only
|
70 |
+
compile = False # do not torch compile the model
|
chess-mamba-vs-xformer/configurator.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
"""
|
2 |
+
Poor Man's Configurator. Probably a terrible idea. Example usage:
|
3 |
+
$ python train.py config/override_file.py --batch_size=32
|
4 |
+
this will first run config/override_file.py, then override batch_size to 32
|
5 |
+
|
6 |
+
The code in this file will be run as follows from e.g. train.py:
|
7 |
+
>>> exec(open('configurator.py').read())
|
8 |
+
|
9 |
+
So it's not a Python module, it's just shuttling this code away from train.py
|
10 |
+
The code in this script then overrides the globals()
|
11 |
+
|
12 |
+
I know people are not going to love this, I just really dislike configuration
|
13 |
+
complexity and having to prepend config. to every single variable. If someone
|
14 |
+
comes up with a better simple Python solution I am all ears.
|
15 |
+
"""
|
16 |
+
|
17 |
+
import sys
|
18 |
+
from ast import literal_eval
|
19 |
+
|
20 |
+
for arg in sys.argv[1:]:
|
21 |
+
if '=' not in arg:
|
22 |
+
# assume it's the name of a config file
|
23 |
+
assert not arg.startswith('--')
|
24 |
+
config_file = arg
|
25 |
+
print(f"Overriding config with {config_file}:")
|
26 |
+
with open(config_file) as f:
|
27 |
+
print(f.read())
|
28 |
+
exec(open(config_file).read())
|
29 |
+
else:
|
30 |
+
# assume it's a --key=value argument
|
31 |
+
assert arg.startswith('--')
|
32 |
+
key, val = arg.split('=')
|
33 |
+
key = key[2:]
|
34 |
+
if key in globals():
|
35 |
+
try:
|
36 |
+
# attempt to eval it (e.g. if bool, number, etc.)
|
37 |
+
attempt = literal_eval(val)
|
38 |
+
except (SyntaxError, ValueError):
|
39 |
+
# if that goes wrong, just use the string
|
40 |
+
attempt = val
|
41 |
+
# ensure the types match ok
|
42 |
+
assert type(attempt) == type(globals()[key])
|
43 |
+
# cross fingers
|
44 |
+
print(f"Overriding: {key} = {attempt}")
|
45 |
+
globals()[key] = attempt
|
46 |
+
else:
|
47 |
+
raise ValueError(f"Unknown config key: {key}")
|
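Note: a minimal sketch of how a training script consumes this configurator, per its own docstring; the default values and file name below are illustrative.

# train.py (sketch)
batch_size = 64        # defaults live in the training script's globals
out_dir = 'out'
exec(open('configurator.py').read())   # may override batch_size, out_dir, ...
print(batch_size, out_dir)

# invoked e.g. as:
#   python train.py config/Mamba/11M.py --batch_size=32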
chess-mamba-vs-xformer/data/anneal/anneal.zip
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37569f7bafca0eb2a7361e4ae29ef1b9fed64dbeb061d2653215c348586f7a7e
|
3 |
+
size 679959998
|
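Note: the three lines above are a Git LFS pointer (sha256 oid and byte size), not the archive itself. A hypothetical integrity check of a fetched anneal.zip against that oid:

import hashlib

sha = hashlib.sha256()
with open('chess-mamba-vs-xformer/data/anneal/anneal.zip', 'rb') as f:
    # hash in 1 MiB chunks to avoid loading the ~680 MB file into memory
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha.update(chunk)
print(sha.hexdigest() == '37569f7bafca0eb2a7361e4ae29ef1b9fed64dbeb061d2653215c348586f7a7e')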
chess-mamba-vs-xformer/mamba.py
ADDED
@@ -0,0 +1,368 @@
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from pscan import pscan
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
|
14 |
+
The major differences are:
|
15 |
+
-the convolution is done with torch.nn.Conv1d
|
16 |
+
-the selective scan is done in PyTorch
|
17 |
+
|
18 |
+
A sequential version of the selective scan is also available for comparison.
|
19 |
+
|
20 |
+
- A Mamba model is composed of several layers, which are ResidualBlock.
|
21 |
+
- A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
|
22 |
+
- This leaves us with the MambaBlock: its input x is (B, L, D) and its output y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
|
23 |
+
First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
|
24 |
+
Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
|
25 |
+
We then multiply it by silu(z).
|
26 |
+
See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class MambaConfig:
|
32 |
+
d_model: int # D
|
33 |
+
n_layers: int
|
34 |
+
dt_rank: Union[int, str] = 'auto'
|
35 |
+
d_state: int = 16 # N in paper/comments
|
36 |
+
expand_factor: int = 2 # E in paper/comments
|
37 |
+
d_conv: int = 4
|
38 |
+
|
39 |
+
dt_min: float = 0.001
|
40 |
+
dt_max: float = 0.1
|
41 |
+
dt_init: str = "random" # "random" or "constant"
|
42 |
+
dt_scale: float = 1.0
|
43 |
+
dt_init_floor = 1e-4
|
44 |
+
|
45 |
+
bias: bool = False
|
46 |
+
conv_bias: bool = True
|
47 |
+
|
48 |
+
pscan: bool = True # use parallel scan mode or sequential mode when training
|
49 |
+
|
50 |
+
def __post_init__(self):
|
51 |
+
self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
|
52 |
+
|
53 |
+
if self.dt_rank == 'auto':
|
54 |
+
self.dt_rank = math.ceil(self.d_model / 16)
|
55 |
+
|
56 |
+
class Mamba(nn.Module):
|
57 |
+
def __init__(self, config: MambaConfig):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.config = config
|
61 |
+
|
62 |
+
self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
|
63 |
+
#self.norm_f = RMSNorm(config.d_model)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
# x : (B, L, D)
|
67 |
+
|
68 |
+
# y : (B, L, D)
|
69 |
+
|
70 |
+
for layer in self.layers:
|
71 |
+
x = layer(x)
|
72 |
+
|
73 |
+
#x = self.norm_f(x)
|
74 |
+
|
75 |
+
return x
|
76 |
+
|
77 |
+
def step(self, x, caches):
|
78 |
+
# x : (B, L, D)
|
79 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
80 |
+
|
81 |
+
# y : (B, L, D)
|
82 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
83 |
+
|
84 |
+
for i, layer in enumerate(self.layers):
|
85 |
+
x, caches[i] = layer.step(x, caches[i])
|
86 |
+
|
87 |
+
return x, caches
|
88 |
+
|
89 |
+
class ResidualBlock(nn.Module):
|
90 |
+
def __init__(self, config: MambaConfig):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.mixer = MambaBlock(config)
|
94 |
+
self.norm = RMSNorm(config.d_model)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
# x : (B, L, D)
|
98 |
+
|
99 |
+
# output : (B, L, D)
|
100 |
+
|
101 |
+
output = self.mixer(self.norm(x)) + x
|
102 |
+
return output
|
103 |
+
|
104 |
+
def step(self, x, cache):
|
105 |
+
# x : (B, D)
|
106 |
+
# cache : (h, inputs)
|
107 |
+
# h : (B, ED, N)
|
108 |
+
# inputs: (B, ED, d_conv-1)
|
109 |
+
|
110 |
+
# output : (B, D)
|
111 |
+
# cache : (h, inputs)
|
112 |
+
|
113 |
+
output, cache = self.mixer.step(self.norm(x), cache)
|
114 |
+
output = output + x
|
115 |
+
return output, cache
|
116 |
+
|
117 |
+
class MambaBlock(nn.Module):
|
118 |
+
def __init__(self, config: MambaConfig):
|
119 |
+
super().__init__()
|
120 |
+
|
121 |
+
self.config = config
|
122 |
+
|
123 |
+
# projects block input from D to 2*ED (two branches)
|
124 |
+
self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
|
125 |
+
|
126 |
+
self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
|
127 |
+
kernel_size=config.d_conv, bias=config.conv_bias,
|
128 |
+
groups=config.d_inner,
|
129 |
+
padding=config.d_conv - 1)
|
130 |
+
|
131 |
+
nn.init.kaiming_normal_(self.conv1d.weight, mode='fan_out', nonlinearity='leaky_relu')
|
132 |
+
|
133 |
+
# projects x to input-dependent Δ, B, C
|
134 |
+
self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
|
135 |
+
|
136 |
+
# projects Δ from dt_rank to d_inner
|
137 |
+
self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
|
138 |
+
|
139 |
+
# dt initialization
|
140 |
+
# dt weights
|
141 |
+
dt_init_std = config.dt_rank**-0.5 * config.dt_scale
|
142 |
+
if config.dt_init == "constant":
|
143 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
144 |
+
elif config.dt_init == "random":
|
145 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
146 |
+
else:
|
147 |
+
raise NotImplementedError
|
148 |
+
|
149 |
+
# dt bias
|
150 |
+
dt = torch.exp(
|
151 |
+
torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
|
152 |
+
).clamp(min=config.dt_init_floor)
|
153 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
154 |
+
with torch.no_grad():
|
155 |
+
self.dt_proj.bias.copy_(inv_dt)
|
156 |
+
#self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
157 |
+
# todo : explain why removed
|
158 |
+
|
159 |
+
# S4D real initialization
|
160 |
+
A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
|
161 |
+
self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
|
162 |
+
self.D = nn.Parameter(torch.ones(config.d_inner))
|
163 |
+
|
164 |
+
# projects block output from ED back to D
|
165 |
+
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
# x : (B, L, D)
|
169 |
+
|
170 |
+
# y : (B, L, D)
|
171 |
+
|
172 |
+
_, L, _ = x.shape
|
173 |
+
|
174 |
+
xz = self.in_proj(x) # (B, L, 2*ED)
|
175 |
+
x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
|
176 |
+
|
177 |
+
# x branch
|
178 |
+
x = x.transpose(1, 2) # (B, ED, L)
|
179 |
+
x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
|
180 |
+
x = x.transpose(1, 2) # (B, L, ED)
|
181 |
+
|
182 |
+
x = F.silu(x)
|
183 |
+
y = self.ssm(x)
|
184 |
+
|
185 |
+
# z branch
|
186 |
+
z = F.silu(z)
|
187 |
+
|
188 |
+
output = y * z
|
189 |
+
output = self.out_proj(output) # (B, L, D)
|
190 |
+
|
191 |
+
return output
|
192 |
+
|
193 |
+
def ssm(self, x):
|
194 |
+
# x : (B, L, ED)
|
195 |
+
|
196 |
+
# y : (B, L, ED)
|
197 |
+
|
198 |
+
A = -torch.exp(self.A_log.float()) # (ED, N)
|
199 |
+
D = self.D.float()
|
200 |
+
# TODO remove .float()
|
201 |
+
|
202 |
+
deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
|
203 |
+
|
204 |
+
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
|
205 |
+
delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
|
206 |
+
|
207 |
+
if self.config.pscan:
|
208 |
+
y = self.selective_scan(x, delta, A, B, C, D)
|
209 |
+
else:
|
210 |
+
y = self.selective_scan_seq(x, delta, A, B, C, D)
|
211 |
+
|
212 |
+
return y
|
213 |
+
|
214 |
+
def selective_scan(self, x, delta, A, B, C, D):
|
215 |
+
# x : (B, L, ED)
|
216 |
+
# Δ : (B, L, ED)
|
217 |
+
# A : (ED, N)
|
218 |
+
# B : (B, L, N)
|
219 |
+
# C : (B, L, N)
|
220 |
+
# D : (ED)
|
221 |
+
|
222 |
+
# y : (B, L, ED)
|
223 |
+
|
224 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
|
225 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
|
226 |
+
|
227 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
|
228 |
+
|
229 |
+
hs = pscan(deltaA, BX)
|
230 |
+
|
231 |
+
y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
|
232 |
+
|
233 |
+
y = y + D * x
|
234 |
+
|
235 |
+
return y
|
236 |
+
|
237 |
+
def selective_scan_seq(self, x, delta, A, B, C, D):
|
238 |
+
# x : (B, L, ED)
|
239 |
+
# Δ : (B, L, ED)
|
240 |
+
# A : (ED, N)
|
241 |
+
# B : (B, L, N)
|
242 |
+
# C : (B, L, N)
|
243 |
+
# D : (ED)
|
244 |
+
|
245 |
+
# y : (B, L, ED)
|
246 |
+
|
247 |
+
_, L, _ = x.shape
|
248 |
+
|
249 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
|
250 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
|
251 |
+
|
252 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
|
253 |
+
|
254 |
+
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
|
255 |
+
hs = []
|
256 |
+
|
257 |
+
for t in range(0, L):
|
258 |
+
h = deltaA[:, t] * h + BX[:, t]
|
259 |
+
hs.append(h)
|
260 |
+
|
261 |
+
hs = torch.stack(hs, dim=1) # (B, L, ED, N)
|
262 |
+
|
263 |
+
y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
|
264 |
+
|
265 |
+
y = y + D * x
|
266 |
+
|
267 |
+
return y
|
268 |
+
|
269 |
+
# -------------------------- inference -------------------------- #
"""
Concerning auto-regressive inference

The cool part of using Mamba: inference is constant with respect to sequence length.
We just have to keep in cache, for each layer, two things:
- the hidden state h (which is (B, ED, N)), as you typically would when doing inference with an RNN
- the last d_conv-1 inputs of the layer, to be able to compute the 1D conv, which is a convolution over the time dimension
  (d_conv is fixed, so this doesn't incur a growing cache as we progress through generating the sequence)
  (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)

Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
h is (B, ED, N), and inputs is (B, ED, d_conv-1).
MambaBlock.step() receives this cache and, along with the output, also returns the updated cache for the next call.

The cache object is initialized as (None, torch.zeros()).
When h is None, the selective scan function detects it and starts with h=0.
The torch.zeros() isn't a problem (it's the same as just feeding the input, because the conv1d is padded).

As we need one such cache per layer, we store a caches object, which is simply a list of cache tuples. (See mamba_lm.py.)
A runnable sketch of this cached decoding loop is added at the end of this file.
"""
|
290 |
+
|
291 |
+
def step(self, x, cache):
|
292 |
+
# x : (B, D)
|
293 |
+
# cache : (h, inputs)
|
294 |
+
# h : (B, ED, N)
|
295 |
+
# inputs : (B, ED, d_conv-1)
|
296 |
+
|
297 |
+
# y : (B, D)
|
298 |
+
# cache : (h, inputs)
|
299 |
+
|
300 |
+
h, inputs = cache
|
301 |
+
|
302 |
+
xz = self.in_proj(x) # (B, 2*ED)
|
303 |
+
x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
|
304 |
+
|
305 |
+
# x branch
|
306 |
+
x_cache = x.unsqueeze(2)
|
307 |
+
x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
|
308 |
+
|
309 |
+
x = F.silu(x)
|
310 |
+
y, h = self.ssm_step(x, h)
|
311 |
+
|
312 |
+
# z branch
|
313 |
+
z = F.silu(z)
|
314 |
+
|
315 |
+
output = y * z
|
316 |
+
output = self.out_proj(output) # (B, D)
|
317 |
+
|
318 |
+
# prepare cache for next call
|
319 |
+
inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
|
320 |
+
cache = (h, inputs)
|
321 |
+
|
322 |
+
return output, cache
|
323 |
+
|
324 |
+
def ssm_step(self, x, h):
|
325 |
+
# x : (B, ED)
|
326 |
+
# h : (B, ED, N)
|
327 |
+
|
328 |
+
# y : (B, ED)
|
329 |
+
# h : (B, ED, N)
|
330 |
+
|
331 |
+
A = -torch.exp(self.A_log.float()) # (ED, N) # todo : avoid recomputing this at every step, since it doesn't depend on the timestep
|
332 |
+
D = self.D.float()
|
333 |
+
# TODO remove .float()
|
334 |
+
|
335 |
+
deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
|
336 |
+
|
337 |
+
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
|
338 |
+
delta = F.softplus(self.dt_proj(delta)) # (B, ED)
|
339 |
+
|
340 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
|
341 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
|
342 |
+
|
343 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
|
344 |
+
|
345 |
+
if h is None:
|
346 |
+
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
|
347 |
+
|
348 |
+
h = deltaA * h + BX # (B, ED, N)
|
349 |
+
|
350 |
+
y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
|
351 |
+
|
352 |
+
y = y + D * x
|
353 |
+
|
354 |
+
# todo : why the h.squeeze(1) ?
|
355 |
+
return y, h.squeeze(1)
|
356 |
+
|
357 |
+
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
|
358 |
+
class RMSNorm(nn.Module):
|
359 |
+
def __init__(self, d_model: int, eps: float = 1e-5):
|
360 |
+
super().__init__()
|
361 |
+
|
362 |
+
self.eps = eps
|
363 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
364 |
+
|
365 |
+
def forward(self, x):
|
366 |
+
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
367 |
+
|
368 |
+
return output
|
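# ---------------------------------------------------------------------------- #
# Hedged usage sketch (added illustration, not part of the original upload):
# how the classes above fit together for a parallel forward pass and for the
# cached, step-by-step inference described in the inference docstring. The
# sizes below are illustrative assumptions only.
if __name__ == "__main__":
    config = MambaConfig(d_model=64, n_layers=2)
    model = Mamba(config)

    # parallel (training-style) forward: (B, L, D) -> (B, L, D)
    x = torch.randn(2, 16, config.d_model)
    y = model(x)
    assert y.shape == (2, 16, config.d_model)

    # cached inference: one (B, D) input per call, constant cost per step
    caches = [(None, torch.zeros(2, config.d_inner, config.d_conv - 1))
              for _ in range(config.n_layers)]
    for _ in range(4):
        step_in = torch.randn(2, config.d_model)
        step_out, caches = model.step(step_in, caches)
        assert step_out.shape == (2, config.d_model)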
chess-mamba-vs-xformer/mamba_lm.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
from dataclasses import dataclass, fields, asdict
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from mamba import Mamba, MambaConfig, RMSNorm
|
9 |
+
|
10 |
+
"""
|
11 |
+
|
12 |
+
Encapsulates a Mamba model as a language model. It has an embedding layer and an LM head which maps the model output to logits.
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
# TODO generate function : batch size != 1 ? (for now B=1)
|
17 |
+
# TODO generate function : top-p sampling
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class MambaLMConfig(MambaConfig):
|
21 |
+
vocab_size: int = 32000
|
22 |
+
pad_vocab_size_multiple: int = 8
|
23 |
+
|
24 |
+
def __post_init__(self):
|
25 |
+
super().__post_init__()
|
26 |
+
|
27 |
+
#if self.vocab_size % self.pad_vocab_size_multiple != 0:
|
28 |
+
# self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
|
29 |
+
|
30 |
+
def to_mamba_config(self) -> MambaConfig:
|
31 |
+
mamba_config_fields = {field.name for field in fields(MambaConfig)}
|
32 |
+
filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
|
33 |
+
return MambaConfig(**filtered_dict)
|
34 |
+
|
35 |
+
# adapted from https://github.com/johnma2006/mamba-minimal
|
36 |
+
def from_pretrained(name: str):
|
37 |
+
"""
|
38 |
+
Returns a model loaded with pretrained weights pulled from HuggingFace.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
name: As of now, supports
|
42 |
+
* 'state-spaces/mamba-2.8b-slimpj'
|
43 |
+
* 'state-spaces/mamba-2.8b'
|
44 |
+
* 'state-spaces/mamba-1.4b'
|
45 |
+
* 'state-spaces/mamba-790m'
|
46 |
+
* 'state-spaces/mamba-370m'
|
47 |
+
* 'state-spaces/mamba-130m'
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
model: a Mamba model configured with the proper parameters and initialized with the proper weights
|
51 |
+
"""
|
52 |
+
|
53 |
+
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
54 |
+
from transformers.utils.hub import cached_file
|
55 |
+
|
56 |
+
def load_config_hf(model_name):
|
57 |
+
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
58 |
+
return json.load(open(resolved_archive_file))
|
59 |
+
|
60 |
+
def load_state_dict_hf(model_name):
|
61 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
62 |
+
return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
|
63 |
+
|
64 |
+
# copy config data
|
65 |
+
config_data = load_config_hf(name)
|
66 |
+
config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
|
67 |
+
|
68 |
+
model = MambaLM(config)
|
69 |
+
|
70 |
+
# copy weights
|
71 |
+
state_dict = load_state_dict_hf(name)
|
72 |
+
|
73 |
+
new_state_dict = {}
|
74 |
+
for key in state_dict:
|
75 |
+
if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
|
76 |
+
new_key = key.replace('backbone.', '')
|
77 |
+
else:
|
78 |
+
new_key = key.replace('backbone', 'mamba')
|
79 |
+
|
80 |
+
new_state_dict[new_key] = state_dict[key]
|
81 |
+
|
82 |
+
model.load_state_dict(new_state_dict)
|
83 |
+
|
84 |
+
return model
|
85 |
+
|
86 |
+
class MambaLM(nn.Module):
|
87 |
+
def __init__(self, lm_config: MambaLMConfig):
|
88 |
+
super().__init__()
|
89 |
+
self.lm_config = lm_config
|
90 |
+
self.config = lm_config.to_mamba_config()
|
91 |
+
|
92 |
+
self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
|
93 |
+
self.mamba = Mamba(self.config)
|
94 |
+
self.norm_f = RMSNorm(self.config.d_model)
|
95 |
+
|
96 |
+
self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
|
97 |
+
self.lm_head.weight = self.embedding.weight
|
98 |
+
|
99 |
+
def forward(self, tokens):
|
100 |
+
# tokens : (B, L)
|
101 |
+
|
102 |
+
# logits : (B, L, vocab_size)
|
103 |
+
|
104 |
+
x = self.embedding(tokens)
|
105 |
+
|
106 |
+
x = self.mamba(x)
|
107 |
+
x = self.norm_f(x)
|
108 |
+
|
109 |
+
logits = self.lm_head(x)
|
110 |
+
|
111 |
+
return logits
|
112 |
+
|
113 |
+
def step(self, token, caches):
|
114 |
+
# token : (B)
|
115 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
116 |
+
|
117 |
+
# logits : (B, vocab_size)
|
118 |
+
# caches : [cache(layer) for all layers], cache : (h, inputs)
|
119 |
+
|
120 |
+
x = self.embedding(token)
|
121 |
+
|
122 |
+
x, caches = self.mamba.step(x, caches)
|
123 |
+
x = self.norm_f(x)
|
124 |
+
|
125 |
+
logits = self.lm_head(x)
|
126 |
+
|
127 |
+
return logits, caches
|
128 |
+
|
129 |
+
# TODO temperature
|
130 |
+
# TODO process prompt in parallel, and pass in sequential mode when prompt is finished ?
|
131 |
+
def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
|
132 |
+
self.eval()
|
133 |
+
|
134 |
+
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens)
|
135 |
+
|
136 |
+
# caches is a list of cache, one per layer
|
137 |
+
# cache is composed of : the hidden state, and the last d_conv-1 inputs
|
138 |
+
# the hidden state because the update is like an RNN
|
139 |
+
# the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
|
140 |
+
caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
|
141 |
+
|
142 |
+
for i in range(input_ids.size(1) + num_tokens - 1):
|
143 |
+
with torch.no_grad():
|
144 |
+
# forward the new output, get new cache
|
145 |
+
next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
|
146 |
+
|
147 |
+
# sample (no sampling when the prompt is being processed)
|
148 |
+
if i+1 >= input_ids.size(1):
|
149 |
+
probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)
|
150 |
+
|
151 |
+
if top_k is not None:
|
152 |
+
values, _ = torch.topk(probs, k=top_k) # (1, k), in descending order: values[:, -1] is the k-th largest probability
|
153 |
+
probs[probs < values[:, -1, None]] = 0
|
154 |
+
probs = probs / probs.sum(axis=1, keepdims=True)
|
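# Worked example (assumed numbers, added for illustration): with top_k=3 and
# probs = [0.05, 0.40, 0.10, 0.25, 0.20], the top-3 values are [0.40, 0.25, 0.20];
# the entries 0.05 and 0.10 fall below the 0.20 threshold and are zeroed, and the
# remainder renormalizes to [0, 0.40/0.85, 0, 0.25/0.85, 0.20/0.85].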
155 |
+
|
156 |
+
if sample:
|
157 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
|
158 |
+
else:
|
159 |
+
next_token = torch.argmax(probs, dim=-1) # (1)
|
160 |
+
|
161 |
+
input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
|
162 |
+
|
163 |
+
output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
|
164 |
+
|
165 |
+
self.train()
|
166 |
+
|
167 |
+
return output
|
168 |
+
|
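# Hedged usage sketch (added illustration, not part of the original upload):
# a minimal forward pass through MambaLM with toy sizes; all numbers here are
# illustrative assumptions, not the repo's training configs.
if __name__ == "__main__":
    lm_config = MambaLMConfig(d_model=64, n_layers=2, vocab_size=32)
    lm = MambaLM(lm_config)
    tokens = torch.randint(0, 32, (2, 16))   # (B, L) integer token ids
    logits = lm(tokens)                      # (B, L, vocab_size)
    assert logits.shape == (2, 16, 32)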
chess-mamba-vs-xformer/openings.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
chess-mamba-vs-xformer/pscan.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
An implementation of the parallel scan operation in PyTorch (Blelloch version).
|
9 |
+
Please see docs/pscan.ipynb for a detailed explanation of what happens here.
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
def npo2(len):
|
14 |
+
"""
|
15 |
+
Returns the next power of 2 above len
|
16 |
+
"""
|
17 |
+
|
18 |
+
return 2 ** math.ceil(math.log2(len))
|
19 |
+
|
20 |
+
def pad_npo2(X):
|
21 |
+
"""
|
22 |
+
Pads input length dim to the next power of 2
|
23 |
+
|
24 |
+
Args:
|
25 |
+
X : (B, L, D, N)
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Y : (B, npo2(L), D, N)
|
29 |
+
"""
|
30 |
+
|
31 |
+
len_npo2 = npo2(X.size(1))
|
32 |
+
pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
|
33 |
+
return F.pad(X, pad_tuple, "constant", 0)
|
34 |
+
|
35 |
+
class PScan(torch.autograd.Function):
|
36 |
+
@staticmethod
|
37 |
+
def pscan(A, X):
|
38 |
+
# A : (B, D, L, N)
|
39 |
+
# X : (B, D, L, N)
|
40 |
+
|
41 |
+
# modifies X in place by doing a parallel scan.
|
42 |
+
# more formally, X will be populated by these values :
|
43 |
+
# H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
|
44 |
+
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
|
45 |
+
|
46 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
47 |
+
|
48 |
+
B, D, L, _ = A.size()
|
49 |
+
num_steps = int(math.log2(L))
|
50 |
+
|
51 |
+
# up sweep (last 2 steps unfolded)
|
52 |
+
Aa = A
|
53 |
+
Xa = X
|
54 |
+
for _ in range(num_steps-2):
|
55 |
+
T = Xa.size(2)
|
56 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
57 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
58 |
+
|
59 |
+
Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
|
60 |
+
Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
|
61 |
+
|
62 |
+
Aa = Aa[:, :, :, 1]
|
63 |
+
Xa = Xa[:, :, :, 1]
|
64 |
+
|
65 |
+
# we have only 4, 2 or 1 nodes left
|
66 |
+
if Xa.size(2) == 4:
|
67 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
68 |
+
Aa[:, :, 1].mul_(Aa[:, :, 0])
|
69 |
+
|
70 |
+
Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
|
71 |
+
elif Xa.size(2) == 2:
|
72 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
73 |
+
return
|
74 |
+
else:
|
75 |
+
return
|
76 |
+
|
77 |
+
# down sweep (first 2 steps unfolded)
|
78 |
+
Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
79 |
+
Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
80 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
|
81 |
+
Aa[:, :, 2].mul_(Aa[:, :, 1])
|
82 |
+
|
83 |
+
for k in range(num_steps-3, -1, -1):
|
84 |
+
Aa = A[:, :, 2**k-1:L:2**k]
|
85 |
+
Xa = X[:, :, 2**k-1:L:2**k]
|
86 |
+
|
87 |
+
T = Xa.size(2)
|
88 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
89 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
90 |
+
|
91 |
+
Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
|
92 |
+
Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def pscan_rev(A, X):
|
96 |
+
# A : (B, D, L, N)
|
97 |
+
# X : (B, D, L, N)
|
98 |
+
|
99 |
+
# the same function as above, but in reverse
|
100 |
+
# (if you flip the input, call pscan, then flip the output, you get what this function outputs)
|
101 |
+
# it is used in the backward pass
|
102 |
+
|
103 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
104 |
+
|
105 |
+
B, D, L, _ = A.size()
|
106 |
+
num_steps = int(math.log2(L))
|
107 |
+
|
108 |
+
# up sweep (last 2 steps unfolded)
|
109 |
+
Aa = A
|
110 |
+
Xa = X
|
111 |
+
for _ in range(num_steps-2):
|
112 |
+
T = Xa.size(2)
|
113 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
114 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
115 |
+
|
116 |
+
Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
|
117 |
+
Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
|
118 |
+
|
119 |
+
Aa = Aa[:, :, :, 0]
|
120 |
+
Xa = Xa[:, :, :, 0]
|
121 |
+
|
122 |
+
# we have only 4, 2 or 1 nodes left
|
123 |
+
if Xa.size(2) == 4:
|
124 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
|
125 |
+
Aa[:, :, 2].mul_(Aa[:, :, 3])
|
126 |
+
|
127 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
|
128 |
+
elif Xa.size(2) == 2:
|
129 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
|
130 |
+
return
|
131 |
+
else:
|
132 |
+
return
|
133 |
+
|
134 |
+
# down sweep (first 2 steps unfolded)
|
135 |
+
Aa = A[:, :, 0:L:2**(num_steps-2)]
|
136 |
+
Xa = X[:, :, 0:L:2**(num_steps-2)]
|
137 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
|
138 |
+
Aa[:, :, 1].mul_(Aa[:, :, 2])
|
139 |
+
|
140 |
+
for k in range(num_steps-3, -1, -1):
|
141 |
+
Aa = A[:, :, 0:L:2**k]
|
142 |
+
Xa = X[:, :, 0:L:2**k]
|
143 |
+
|
144 |
+
T = Xa.size(2)
|
145 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
146 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
147 |
+
|
148 |
+
Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
|
149 |
+
Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def forward(ctx, A_in, X_in):
|
153 |
+
"""
|
154 |
+
Applies the parallel scan operation, as defined above. Returns a new tensor.
|
155 |
+
If you can, prefer sequence lengths that are powers of two.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
A_in : (B, L, D, N)
|
159 |
+
X_in : (B, L, D, N)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
H : (B, L, D, N)
|
163 |
+
"""
|
164 |
+
|
165 |
+
L = X_in.size(1)
|
166 |
+
|
167 |
+
# cloning is required because of the in-place ops
|
168 |
+
if L == npo2(L):
|
169 |
+
A = A_in.clone()
|
170 |
+
X = X_in.clone()
|
171 |
+
else:
|
172 |
+
# pad tensors (and clone btw)
|
173 |
+
A = pad_npo2(A_in) # (B, npo2(L), D, N)
|
174 |
+
X = pad_npo2(X_in) # (B, npo2(L), D, N)
|
175 |
+
|
176 |
+
# prepare tensors
|
177 |
+
A = A.transpose(2, 1) # (B, D, npo2(L), N)
|
178 |
+
X = X.transpose(2, 1) # (B, D, npo2(L), N)
|
179 |
+
|
180 |
+
# parallel scan (modifies X in-place)
|
181 |
+
PScan.pscan(A, X)
|
182 |
+
|
183 |
+
ctx.save_for_backward(A_in, X)
|
184 |
+
|
185 |
+
# slice [:, :L] (cut if there was padding)
|
186 |
+
return X.transpose(2, 1)[:, :L]
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def backward(ctx, grad_output_in):
|
190 |
+
"""
|
191 |
+
Flows the gradient from the output to the input. Returns two new tensors.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
ctx : A_in : (B, L, D, N), X : (B, D, L, N)
|
195 |
+
grad_output_in : (B, L, D, N)
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
gradA : (B, L, D, N), gradX : (B, L, D, N)
|
199 |
+
"""
|
200 |
+
|
201 |
+
A_in, X = ctx.saved_tensors
|
202 |
+
|
203 |
+
L = grad_output_in.size(1)
|
204 |
+
|
205 |
+
# cloning is required because of the in-place ops
|
206 |
+
if L == npo2(L):
|
207 |
+
grad_output = grad_output_in.clone()
|
208 |
+
# the next padding will clone A_in
|
209 |
+
else:
|
210 |
+
grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
|
211 |
+
A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
|
212 |
+
|
213 |
+
# prepare tensors
|
214 |
+
grad_output = grad_output.transpose(2, 1)
|
215 |
+
A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
|
216 |
+
A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
|
217 |
+
|
218 |
+
# reverse parallel scan (modifies grad_output in-place)
|
219 |
+
PScan.pscan_rev(A, grad_output)
|
220 |
+
|
221 |
+
Q = torch.zeros_like(X)
|
222 |
+
Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
|
223 |
+
|
224 |
+
return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
|
225 |
+
|
226 |
+
pscan = PScan.apply
|
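# Hedged correctness check (added illustration, not part of the original upload):
# compares the parallel scan against the naive sequential recurrence
# H[t] = A[t] * H[t-1] + X[t] with H starting at 0, on small assumed sizes.
if __name__ == "__main__":
    B, L, D, N = 2, 32, 4, 8
    A = torch.rand(B, L, D, N)
    X = torch.rand(B, L, D, N)

    # sequential reference
    h = torch.zeros(B, D, N)
    H_seq = torch.zeros_like(X)
    for t in range(L):
        h = A[:, t] * h + X[:, t]
        H_seq[:, t] = h

    # parallel scan
    H_par = pscan(A, X)
    print("max abs diff:", (H_par - H_seq).abs().max().item())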
chess-mamba-vs-xformer/train_bygame.py
ADDED
@@ -0,0 +1,541 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import math
|
4 |
+
import pickle
|
5 |
+
from contextlib import nullcontext
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
10 |
+
from torch.distributed import init_process_group, destroy_process_group
|
11 |
+
import pyarrow.parquet as pq
|
12 |
+
import random
|
13 |
+
from torch.utils.data import Dataset, DataLoader
|
14 |
+
import glob
|
15 |
+
|
16 |
+
# -----------------------------------------------------------------------------
|
17 |
+
# default config values designed for Mamba model training
|
18 |
+
# I/O
|
19 |
+
out_dir = 'out'
|
20 |
+
eval_interval = 2000
|
21 |
+
log_interval = 1
|
22 |
+
eval_iters = 5
|
23 |
+
eval_only = False
|
24 |
+
always_save_checkpoint = True
|
25 |
+
init_from = 'resume' # 'scratch', 'resume', 'anneal', or Mamba model name
|
26 |
+
# wandb logging
|
27 |
+
wandb_log = False
|
28 |
+
wandb_project = 'mamba'
|
29 |
+
wandb_run_name = 'mamba_run' # modify as needed
|
30 |
+
# data
|
31 |
+
dataset = 'chess' # specify your dataset
|
32 |
+
gradient_accumulation_steps = 5 * 8
|
33 |
+
batch_size = 12
|
34 |
+
base_batch_size = batch_size
|
35 |
+
effective_batch_size = batch_size
|
36 |
+
max_seq_len = 1024 # For xformer, this is the block size
|
37 |
+
train_file_update_interval = 7
|
38 |
+
|
39 |
+
# model
|
40 |
+
model_type = 'mamba'
|
41 |
+
# TODO: add 'xformer' type / model parameters. move model imports to after exec() (when these values are finalized)
|
42 |
+
n_layer = 12
|
43 |
+
d_model = 768
|
44 |
+
dt_rank = 'auto'
|
45 |
+
d_state = 16
|
46 |
+
expand_factor = 2
|
47 |
+
bias = False
|
48 |
+
conv_bias = True
|
49 |
+
pscan = True
|
50 |
+
vocab_size = 32
|
51 |
+
move_num_in_gamestate = True
|
52 |
+
# xformer-specific params. Note that n_layer, vocab_size, move_num_in_gamestate, and bias are shared by both model types
|
53 |
+
n_head = 12
|
54 |
+
n_embd = 768
|
55 |
+
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
|
56 |
+
|
57 |
+
# optimizer settings
|
58 |
+
learning_rate = 6e-4
|
59 |
+
max_iters = 600000 # max_iters is for auto-stopping end of stable phase
|
60 |
+
weight_decay = 1e-1
|
61 |
+
beta1 = 0.9
|
62 |
+
beta2 = 0.95
|
63 |
+
grad_clip = 0.5
|
64 |
+
auto_clip = False
|
65 |
+
auto_clip_max = 0.5
|
66 |
+
auto_clip_min = 3.333e-3
|
67 |
+
grad_clip_start_size = 100
|
68 |
+
grad_clip_max_size = 500
|
69 |
+
grad_clip_percentile = 10
|
70 |
+
# learning rate decay settings
|
71 |
+
decay_lr = True
|
72 |
+
warmup_iters = 2000
|
73 |
+
min_lr = 6e-5
|
74 |
+
# DDP settings
|
75 |
+
backend = 'nccl'
|
76 |
+
# system
|
77 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
78 |
+
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
|
79 |
+
compile = False # set to True if using PyTorch 2.0
|
80 |
+
# -----------------------------------------------------------------------------
|
81 |
+
|
82 |
+
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
83 |
+
exec(open('configurator.py').read()) # overrides from command line or config file
|
84 |
+
config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
85 |
+
# -----------------------------------------------------------------------------
|
86 |
+
|
87 |
+
|
88 |
+
anneal_checkpoint = 'anneal/ckpt.pt'
|
89 |
+
anneal_dir = os.path.join(out_dir, 'anneal/')
|
90 |
+
anneal_start_iters = None # Set at init
|
91 |
+
anneal_decay_iters = None # Set at init
|
92 |
+
|
93 |
+
if model_type == 'mamba':
|
94 |
+
from mamba_lm import MambaLM, MambaLMConfig
|
95 |
+
model_config = MambaLMConfig(
|
96 |
+
d_model=d_model,
|
97 |
+
n_layers=n_layer,
|
98 |
+
dt_rank=dt_rank,
|
99 |
+
d_state=d_state,
|
100 |
+
expand_factor=expand_factor,
|
101 |
+
bias=bias,
|
102 |
+
conv_bias=conv_bias,
|
103 |
+
pscan=pscan,
|
104 |
+
vocab_size=vocab_size
|
105 |
+
)
|
106 |
+
elif model_type == 'xformer':
|
107 |
+
from xformer import GPTConfig, GPT
|
108 |
+
model_config = GPTConfig(
|
109 |
+
n_layer=n_layer,
|
110 |
+
n_head=n_head,
|
111 |
+
n_embd=n_embd,
|
112 |
+
block_size=max_seq_len,
|
113 |
+
bias=bias,
|
114 |
+
vocab_size=vocab_size,
|
115 |
+
dropout=dropout)
|
116 |
+
else:
|
117 |
+
print(f"Unknown model_type {model_type}.")
|
118 |
+
exit()
|
119 |
+
|
120 |
+
# DDP and other initializations
|
121 |
+
ddp = int(os.environ.get('RANK', -1)) != -1
|
122 |
+
if ddp:
|
123 |
+
init_process_group(backend=backend)
|
124 |
+
ddp_rank = int(os.environ['RANK'])
|
125 |
+
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
126 |
+
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
127 |
+
device = f'cuda:{ddp_local_rank}'
|
128 |
+
torch.cuda.set_device(device)
|
129 |
+
master_process = ddp_rank == 0
|
130 |
+
seed_offset = ddp_rank
|
131 |
+
assert gradient_accumulation_steps % ddp_world_size == 0
|
132 |
+
gradient_accumulation_steps //= ddp_world_size
|
133 |
+
else:
|
134 |
+
master_process = True
|
135 |
+
seed_offset = 0
|
136 |
+
ddp_world_size = 1
|
137 |
+
|
138 |
+
if master_process:
|
139 |
+
os.makedirs(out_dir, exist_ok=True)
|
140 |
+
os.makedirs(anneal_dir, exist_ok=True)
|
141 |
+
torch.manual_seed(1337 + seed_offset)
|
142 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
143 |
+
torch.backends.cudnn.allow_tf32 = True
|
144 |
+
device_type = 'cuda' if 'cuda' in device else 'cpu'
|
145 |
+
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype]
|
146 |
+
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
147 |
+
|
148 |
+
# poor man's data loader
|
149 |
+
data_dir = os.path.join('data', dataset)
|
150 |
+
current_train_file_index = 0
|
151 |
+
train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
|
152 |
+
train_datasets = []
|
153 |
+
for f in train_files:
|
154 |
+
dataset = pq.read_table(f).to_pandas()
|
155 |
+
dataset = dataset[dataset['tokenized'].apply(len) >= 8]
|
156 |
+
train_datasets.append(dataset)
|
157 |
+
#val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
|
158 |
+
#val_data = val_data[val_data['tokenized'].apply(len) >= 8]
|
159 |
+
truncated_games_count = 0
|
160 |
+
total_games_count = 0
|
161 |
+
games_seen = 0
|
162 |
+
tokens_seen = 0
|
163 |
+
tokens_seen_padded = 0
|
164 |
+
def get_batch(split):
|
165 |
+
global truncated_games_count, total_games_count, current_train_file_index, tokens_seen, tokens_seen_padded
|
166 |
+
|
167 |
+
# Randomly select batch_size games
|
168 |
+
dataset = train_datasets[current_train_file_index] if split == 'train' else None # else val_data # Use the correct DataFrame based on the split
|
169 |
+
sample_df = dataset.sample(batch_size)
|
170 |
+
games = sample_df['tokenized'].tolist()
|
171 |
+
|
172 |
+
# Prepare sequences tensor for the batch
|
173 |
+
max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
|
174 |
+
pad_to = max_length_in_batch #if model_type == 'mamba' else max_seq_len
|
175 |
+
sequences = torch.zeros((batch_size, pad_to), dtype=torch.int64)
|
176 |
+
|
177 |
+
for i, game in enumerate(games):
|
178 |
+
total_games_count += 1
|
179 |
+
game_len = min(len(game), pad_to)
|
180 |
+
tokens_seen += game_len
|
181 |
+
tokens_seen_padded += pad_to
|
182 |
+
sequences[i, :game_len] = torch.tensor(game[:game_len], dtype=torch.int64)
|
183 |
+
|
184 |
+
if (total_games_count // batch_size) % train_file_update_interval == 0:
|
185 |
+
current_train_file_index = random.randint(0, len(train_files) - 1)
|
186 |
+
# print(f"Switched to file: {train_files[current_train_file_index]}")
|
187 |
+
|
188 |
+
if device_type == 'cuda':
|
189 |
+
sequences = sequences.pin_memory().to(device, non_blocking=True)
|
190 |
+
else:
|
191 |
+
sequences = sequences.to(device)
|
192 |
+
|
193 |
+
return sequences, max_length_in_batch
|
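# Example (assumed numbers, added for illustration): with batch_size=3 and sampled
# games of 12, 40, and 900 tokens (max_seq_len=1024), max_length_in_batch is 900,
# so `sequences` is a (3, 900) tensor with the two shorter games zero-padded on the right.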
194 |
+
|
195 |
+
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
|
196 |
+
iter_num = 0
|
197 |
+
best_val_loss = 1e9
|
198 |
+
|
199 |
+
# attempt to derive vocab_size from the dataset
|
200 |
+
meta_path = os.path.join(data_dir, 'meta.pkl')
|
201 |
+
meta_vocab_size = None
|
202 |
+
if not move_num_in_gamestate:
|
203 |
+
meta_vocab_size = 28
|
204 |
+
elif os.path.exists(meta_path):
|
205 |
+
with open(meta_path, 'rb') as f:
|
206 |
+
meta = pickle.load(f)
|
207 |
+
meta_vocab_size = meta['vocab_size']
|
208 |
+
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
|
209 |
+
|
210 |
+
# Model initialization
|
211 |
+
if init_from == 'scratch':
|
212 |
+
print(f"Initializing a new {model_type} model from scratch")
|
213 |
+
if meta_vocab_size is None:
|
214 |
+
print(f"defaulting to vocab_size of {vocab_size}")
|
215 |
+
else:
|
216 |
+
model_config.vocab_size = meta_vocab_size
|
217 |
+
if model_type == 'mamba':
|
218 |
+
model = MambaLM(model_config)
|
219 |
+
else:
|
220 |
+
model = GPT(model_config)
|
221 |
+
if auto_clip:
|
222 |
+
grad_clip = 0
|
223 |
+
config['grad_clip'] = 0
|
224 |
+
grad_norm_history = []
|
225 |
+
elif init_from == 'resume' or init_from == 'anneal':
|
226 |
+
print(f"Resuming training from {out_dir}")
|
227 |
+
if init_from == 'anneal':
|
228 |
+
ckpt_path = os.path.join(out_dir, anneal_checkpoint)
|
229 |
+
else:
|
230 |
+
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
231 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
232 |
+
model_config = checkpoint['model_args']
|
233 |
+
if model_type == 'mamba':
|
234 |
+
model = MambaLM(model_config)
|
235 |
+
else:
|
236 |
+
model = GPT(model_config)
|
237 |
+
state_dict = checkpoint['model']
|
238 |
+
# fix the keys of the state dictionary :(
|
239 |
+
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
|
240 |
+
unwanted_prefix = '_orig_mod.'
|
241 |
+
for k,v in list(state_dict.items()):
|
242 |
+
if k.startswith(unwanted_prefix):
|
243 |
+
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
244 |
+
model.load_state_dict(state_dict)
|
245 |
+
if 'effective_batch_size' not in checkpoint['config']:
|
246 |
+
print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effetive batch size is changed.")
|
247 |
+
checkpoint['config']['effective_batch_size'] = effective_batch_size
|
248 |
+
iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
|
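# Example (assumed numbers, added for illustration): a checkpoint saved at iter_num=10000
# with effective_batch_size=400, resumed with effective_batch_size=800, restarts at
# iter_num=5000, so the schedule and bookkeeping stay aligned with games actually seen.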
249 |
+
if 'games_seen' in checkpoint:
|
250 |
+
games_seen = checkpoint['games_seen']
|
251 |
+
else:
|
252 |
+
games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
|
253 |
+
checkpoint['games_seen'] = games_seen
|
254 |
+
print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
|
255 |
+
tokens_seen = checkpoint.get('tokens_seen', 0)
|
256 |
+
tokens_seen_padded = checkpoint.get('tokens_seen_padded', 0)
|
257 |
+
best_val_loss = checkpoint['best_val_loss']
|
258 |
+
print(f"Best val loss: {best_val_loss}")
|
259 |
+
if auto_clip:
|
260 |
+
grad_clip = checkpoint['config']['grad_clip']
|
261 |
+
config['grad_clip'] = grad_clip
|
262 |
+
#grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
|
263 |
+
grad_norm_history = checkpoint.get('grad_norm_history', [])
|
264 |
+
if init_from == 'anneal':
|
265 |
+
print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
|
266 |
+
anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
|
267 |
+
anneal_decay_iters = iter_num / 8 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters'] # /9 is the original value, but going deeper on lr too (can always take an earlier ckpt during anneal if it doesn't keep improving)... have used 6.75
|
268 |
+
print(anneal_start_iters)
|
269 |
+
print(anneal_decay_iters)
|
270 |
+
if 'anneal_start_iters' not in checkpoint:
|
271 |
+
grad_clip = 0
|
272 |
+
config['grad_clip'] = 0
|
273 |
+
grad_norm_history = []
|
274 |
+
print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} / until iter_num {anneal_start_iters + anneal_decay_iters}.")
|
275 |
+
out_dir = anneal_dir
|
276 |
+
weight_decay = weight_decay / 12.5 # / 17.0
|
277 |
+
beta2 = np.sqrt(beta2) * beta2
|
278 |
+
auto_clip = True
|
279 |
+
grad_clip_percentile = 6.75
|
280 |
+
elif init_from.startswith('state-spaces'):
|
281 |
+
print(f"Initializing from Mamba pre-trained weights: {init_from}")
|
282 |
+
model = from_pretrained(init_from)
|
283 |
+
model_config = model.config
|
284 |
+
else:
|
285 |
+
raise ValueError("Invalid init_from value")
|
286 |
+
|
287 |
+
model.to(device)
|
288 |
+
|
289 |
+
print(f'Model with {sum([p.numel() for p in model.parameters()])} parameters loaded.')
|
290 |
+
|
291 |
+
# Optimizer and GradScaler
|
292 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
|
293 |
+
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
|
294 |
+
if init_from == 'resume':
|
295 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
296 |
+
checkpoint = None
|
297 |
+
|
298 |
+
# Compile the model if using PyTorch 2.0
|
299 |
+
if compile:
|
300 |
+
print("compiling the model... (takes a ~minute)")
|
301 |
+
model = torch.compile(model)
|
302 |
+
|
303 |
+
# Wrap model in DDP container if necessary
|
304 |
+
if ddp:
|
305 |
+
model = DDP(model, device_ids=[ddp_local_rank])
|
306 |
+
|
307 |
+
|
308 |
+
def batch_to_loss(sequences, max_length_in_batch):
|
309 |
+
if model_type == 'mamba':
|
310 |
+
logits = model(sequences[:, :-1]) # Forward pass, exclude last token for input
|
311 |
+
# Compute loss (assuming next token prediction task)
|
312 |
+
targets = sequences[:, 1:].reshape(-1) # Shifted by one for next token prediction
|
313 |
+
return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
|
314 |
+
else:
|
315 |
+
inputs = sequences[:, :-1]
|
316 |
+
targets = sequences[:, 1:].reshape(-1)
|
317 |
+
_, loss = model(inputs, targets)
|
318 |
+
return loss
|
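# Shape note (assumed example, added for illustration): for sequences of shape (B, T) = (12, 256),
# inputs is (12, 255), logits are (12, 255, vocab_size), and targets is the same window
# shifted left by one token and flattened to (12*255,) for the cross-entropy loss.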
319 |
+
|
320 |
+
|
321 |
+
@torch.no_grad()
|
322 |
+
def estimate_loss():
|
323 |
+
global tokens_seen, tokens_seen_padded
|
324 |
+
out = {}
|
325 |
+
model.eval()
|
326 |
+
tokens_seen_b4 = tokens_seen
|
327 |
+
tokens_seen_padded_b4 = tokens_seen_padded
|
328 |
+
for split in ['train']: #['train', 'val']:
|
329 |
+
losses = torch.zeros(eval_iters)
|
330 |
+
for k in range(eval_iters):
|
331 |
+
loss = batch_to_loss(*get_batch(split))
|
332 |
+
losses[k] = loss.item()
|
333 |
+
|
334 |
+
split = 'val' # Temporary hack
|
335 |
+
out[split] = losses.mean()
|
336 |
+
tokens_seen = tokens_seen_b4
|
337 |
+
tokens_seen_padded = tokens_seen_padded_b4
|
338 |
+
model.train()
|
339 |
+
return out
|
340 |
+
|
341 |
+
|
342 |
+
# WSD scheduler
|
343 |
+
def get_lr(it):
|
344 |
+
if init_from == 'anneal':
|
345 |
+
# Linear decay from max LR to min LR over (anneal_start_iters / 9) iters
|
346 |
+
decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
|
347 |
+
return learning_rate - decay_ratio * (learning_rate - min_lr)
|
348 |
+
|
349 |
+
if it < warmup_iters:
|
350 |
+
# Warmup
|
351 |
+
return learning_rate * it / warmup_iters
|
352 |
+
|
353 |
+
# Stable max LR
|
354 |
+
return learning_rate
|
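# Schedule illustration (assumed defaults, added for illustration): with warmup_iters=2000
# and learning_rate=6e-4, get_lr(1000) returns 3e-4 (halfway through warmup) and
# get_lr(10000) returns the full 6e-4 (stable phase); only an 'anneal' run takes the
# linear-decay branch, dropping from 6e-4 toward min_lr=6e-5 over anneal_decay_iters.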
355 |
+
|
356 |
+
# Logging setup
|
357 |
+
if wandb_log and master_process:
|
358 |
+
import wandb
|
359 |
+
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
360 |
+
|
361 |
+
# Training loop
|
362 |
+
local_iter_num = 0 # Number of iterations in the lifetime of this process
|
363 |
+
last_crossed_multiple = 0
|
364 |
+
save_every_n_games = 150000
|
365 |
+
raw_model = model.module if ddp else model # Unwrap DDP container if needed
|
366 |
+
|
367 |
+
# initial save
|
368 |
+
if init_from == 'scratch':
|
369 |
+
checkpoint = {
|
370 |
+
'model': raw_model.state_dict(),
|
371 |
+
'optimizer': optimizer.state_dict(),
|
372 |
+
'model_args': model_config,
|
373 |
+
'iter_num': 0,
|
374 |
+
"games_seen": 0,
|
375 |
+
"tokens_seen": 0,
|
376 |
+
"tokens_seen_padded": 0,
|
377 |
+
'best_val_loss': best_val_loss,
|
378 |
+
'config': config,
|
379 |
+
}
|
380 |
+
checkpoint['grad_norm_history'] = grad_norm_history
|
381 |
+
print(f"saving checkpoint to {out_dir}\n")
|
382 |
+
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
383 |
+
|
384 |
+
t0 = time.time()
|
385 |
+
while True:
|
386 |
+
# Determine and set the learning rate for this iteration
|
387 |
+
lr = get_lr(iter_num) if decay_lr else learning_rate
|
388 |
+
for param_group in optimizer.param_groups:
|
389 |
+
param_group['lr'] = lr
|
390 |
+
|
391 |
+
# Evaluate the loss on train/val sets and write checkpoints
|
392 |
+
if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
|
393 |
+
torch.cuda.empty_cache()
|
394 |
+
losses = estimate_loss()
|
395 |
+
print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")
|
396 |
+
if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
|
397 |
+
grad_clip_prev = grad_clip
|
398 |
+
grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
|
399 |
+
grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min)
|
400 |
+
# Transition between grad_clips smoothly, weighed to new value
|
401 |
+
grad_clip = (grad_clip*9.0 + grad_clip_prev*4.0) / 13.0
|
402 |
+
grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min) # should never actually clip here
|
403 |
+
config['grad_clip'] = grad_clip
|
404 |
+
print(f"Auto adjusted grad_clip to {grad_clip}")
|
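# Worked example (assumed numbers, added for illustration): if the 10th-percentile of
# recent grad norms is 0.20 and the previous grad_clip was 0.33, the blended value is
# (0.20*9.0 + 0.33*4.0) / 13.0 ≈ 0.24, then clamped to [auto_clip_min, auto_clip_max].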
405 |
+
torch.cuda.empty_cache()
|
406 |
+
if wandb_log:
|
407 |
+
wandb.log({
|
408 |
+
"etc/iter": iter_num,
|
409 |
+
"etc/games": games_seen,
|
410 |
+
"etc/tokens_seen": tokens_seen,
|
411 |
+
"etc/tokens_seen_padded": tokens_seen_padded,
|
412 |
+
"etc/grad_clip": grad_clip,
|
413 |
+
"etc/lr": lr,
|
414 |
+
"val/loss": losses['val'],
|
415 |
+
|
416 |
+
})
|
417 |
+
if losses['val'] < best_val_loss or always_save_checkpoint:
|
418 |
+
if iter_num > 0:
|
419 |
+
checkpoint = {
|
420 |
+
'model': raw_model.state_dict(),
|
421 |
+
'optimizer': optimizer.state_dict(),
|
422 |
+
'model_args': model_config,
|
423 |
+
'iter_num': iter_num,
|
424 |
+
"games_seen": games_seen,
|
425 |
+
"tokens_seen": tokens_seen,
|
426 |
+
"tokens_seen_padded": tokens_seen_padded,
|
427 |
+
'best_val_loss': min(best_val_loss, losses['val']),
|
428 |
+
'config': config,
|
429 |
+
}
|
430 |
+
checkpoint['grad_norm_history'] = grad_norm_history
|
431 |
+
if init_from == 'anneal':
|
432 |
+
checkpoint['anneal_start_iters'] = anneal_start_iters
|
433 |
+
checkpoint['anneal_decay_iters'] = anneal_decay_iters
|
434 |
+
print(f"saving checkpoint to {out_dir}\n")
|
435 |
+
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
436 |
+
current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
|
437 |
+
if losses['val'] < best_val_loss: # Temporary / only good after it's settled
|
438 |
+
best_val_loss = losses['val']
|
439 |
+
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
|
440 |
+
elif current_nearest_multiple != last_crossed_multiple: # elif so we don't double up
|
441 |
+
last_crossed_multiple = current_nearest_multiple
|
442 |
+
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))
|
443 |
+
|
444 |
+
if iter_num == 0 and eval_only:
|
445 |
+
break
|
446 |
+
|
447 |
+
# Forward and backward pass
|
448 |
+
for micro_step in range(gradient_accumulation_steps):
|
449 |
+
if ddp:
|
450 |
+
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
|
451 |
+
|
452 |
+
sequences, max_length_in_batch = get_batch('train') # Fetch the training data
|
453 |
+
with ctx:
|
454 |
+
loss = batch_to_loss(sequences, max_length_in_batch)
|
455 |
+
loss = loss / gradient_accumulation_steps
|
456 |
+
|
457 |
+
scaler.scale(loss).backward()
|
458 |
+
#print('.', end='')
|
459 |
+
|
460 |
+
# clip the gradient
|
461 |
+
if grad_clip != 0.0 or auto_clip:
|
462 |
+
scaler.unscale_(optimizer)
|
463 |
+
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9) # The 0 check is for auto_clip enabled but not enough history
|
464 |
+
grad_norm_history.append(total_norm.item())
|
465 |
+
grad_norm_history = grad_norm_history[-grad_clip_max_size:]
|
466 |
+
|
467 |
+
# step the optimizer and scaler if training in fp16
|
468 |
+
scaler.step(optimizer)
|
469 |
+
scaler.update()
|
470 |
+
# flush the gradients as soon as we can, no need for this memory anymore
|
471 |
+
optimizer.zero_grad(set_to_none=True)
|
472 |
+
torch.cuda.empty_cache()
|
473 |
+
|
474 |
+
# timing and logging
|
475 |
+
t1 = time.time()
|
476 |
+
dt = t1 - t0
|
477 |
+
t0 = t1
|
478 |
+
if iter_num % log_interval == 0 and master_process:
|
479 |
+
# get loss as float. note: this is a CPU-GPU sync point
|
480 |
+
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
|
481 |
+
lossf = loss.item() * gradient_accumulation_steps
|
482 |
+
print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
483 |
+
if wandb_log:
|
484 |
+
wandb.log({
|
485 |
+
"etc/iter": iter_num,
|
486 |
+
"etc/games": games_seen,
|
487 |
+
"etc/tokens_seen": tokens_seen,
|
488 |
+
"etc/tokens_seen_padded": tokens_seen_padded,
|
489 |
+
"etc/grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
|
490 |
+
"etc/lr": lr,
|
491 |
+
"train/loss": lossf,
|
492 |
+
})
|
493 |
+
iter_num += 1
|
494 |
+
local_iter_num += 1
|
495 |
+
games_seen += effective_batch_size
|
496 |
+
|
497 |
+
# termination conditions
|
498 |
+
if iter_num > max_iters and not init_from == 'anneal': # max iters is for auto-stopping end of stable phase
|
499 |
+
checkpoint = {
|
500 |
+
'model': raw_model.state_dict(),
|
501 |
+
'optimizer': optimizer.state_dict(),
|
502 |
+
'model_args': model_config,
|
503 |
+
'iter_num': iter_num,
|
504 |
+
"games_seen": games_seen,
|
505 |
+
"tokens_seen": tokens_seen,
|
506 |
+
"tokens_seen_padded": tokens_seen_padded,
|
507 |
+
'best_val_loss': best_val_loss,
|
508 |
+
'config': config,
|
509 |
+
}
|
510 |
+
checkpoint['grad_norm_history'] = grad_norm_history
|
511 |
+
if init_from == 'anneal':
|
512 |
+
checkpoint['anneal_start_iters'] = anneal_start_iters
|
513 |
+
checkpoint['anneal_decay_iters'] = anneal_decay_iters
|
514 |
+
print(f"Max_iters reached. Saving pre-anneal checkpoint to {anneal_checkpoint}")
|
515 |
+
torch.save(checkpoint, os.path.join(out_dir, anneal_checkpoint))
|
516 |
+
break
|
517 |
+
if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
|
518 |
+
checkpoint = {
|
519 |
+
'model': raw_model.state_dict(),
|
520 |
+
'optimizer': optimizer.state_dict(),
|
521 |
+
'model_args': model_config,
|
522 |
+
'iter_num': iter_num,
|
523 |
+
"games_seen": games_seen,
|
524 |
+
"tokens_seen": tokens_seen,
|
525 |
+
"tokens_seen_padded": tokens_seen_padded,
|
526 |
+
'best_val_loss': best_val_loss,
|
527 |
+
'config': config,
|
528 |
+
}
|
529 |
+
checkpoint['grad_norm_history'] = grad_norm_history
|
530 |
+
if init_from == 'anneal':
|
531 |
+
checkpoint['anneal_start_iters'] = anneal_start_iters
|
532 |
+
checkpoint['anneal_decay_iters'] = anneal_decay_iters
|
533 |
+
print(f"Anneal complete. Saving checkpoint to {out_dir}")
|
534 |
+
torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
|
535 |
+
break
|
536 |
+
|
537 |
+
|
538 |
+
|
539 |
+
if ddp:
|
540 |
+
destroy_process_group()
|
541 |
+
|
chess-mamba-vs-xformer/xformer.py
ADDED
@@ -0,0 +1,330 @@
1 |
+
"""
|
2 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
3 |
+
References:
|
4 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
5 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention makes the GPU go brrrrr, but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
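
    # Shape walk-through for one forward pass, with illustrative sizes (not taken from any
    # particular config): x of shape (B, T, C) = (2, 8, 64) with n_head = 4 gives
    #   q, k, v : (2, 4, 8, 16) each after the view/transpose above (hs = C // n_head = 16)
    #   att     : (2, 4, 8, 8) = q @ k^T / sqrt(hs), masked so position t only sees positions <= t
    #   y       : (2, 4, 8, 16) = att @ v, then transposed/viewed back to (2, 8, 64)
    # The flash path computes the same masked softmax(q k^T / sqrt(hs)) v product in a fused kernel.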

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
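
# A minimal usage sketch for GPTConfig; the values below are illustrative only, not a
# recommended or project-specific setup:
#
#   config = GPTConfig(block_size=256, vocab_size=32, n_layer=4, n_head=4,
#                      n_embd=128, dropout=0.0, bias=False)
#   model = GPT(config)  # GPT is defined below
#
# With weight tying, the parameter count is roughly 12 * n_layer * n_embd**2 for the
# attention + MLP blocks, plus vocab_size * n_embd for the (tied) token embedding and
# block_size * n_embd for the position embedding, so the sketch above is ~0.8M parameters.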

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss
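
    # The two calling modes, sketched with illustrative shapes (B=4, T=32):
    #   idx = torch.randint(0, config.vocab_size, (4, 32))
    #   logits, loss = model(idx, targets=idx)  # training: logits (4, 32, vocab_size), scalar loss
    #   logits, loss = model(idx)               # inference: logits (4, 1, vocab_size), loss is None
    # (targets=idx is only a placeholder here; real language-model training shifts the targets by one token.)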

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden; see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model
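
    # Usage sketch (requires the `transformers` package and downloads the GPT-2 weights):
    #   model = GPT.from_pretrained('gpt2', override_args=dict(dropout=0.0))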

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer
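
    # Usage sketch; the hyperparameter values here are illustrative, not prescribed:
    #   optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4,
    #                                          betas=(0.9, 0.95), device_type='cuda')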

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu
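
    # Worked example with the default GPT-2-sized config above (L=12, H=12, Q=64, T=1024,
    # N ~ 124e6), numbers rounded:
    #   flops_per_token  = 6*124e6 + 12*12*12*64*1024  ~ 7.4e8 + 1.1e8  ~ 8.6e8
    #   flops_per_fwdbwd = 8.6e8 * 1024                ~ 8.8e11
    # With fwdbwd_per_iter=1 and dt=0.5s that is ~1.8e12 FLOP/s, i.e. mfu ~ 0.006 of the
    # 312e12 A100 peak; real runs push many sequences per iteration via fwdbwd_per_iter.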

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
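
if __name__ == "__main__":
    # Minimal smoke test with an untrained, deliberately tiny model; every value below is
    # arbitrary and the sampled tokens are meaningless without training.
    config = GPTConfig(block_size=32, vocab_size=64, n_layer=2, n_head=2,
                       n_embd=64, dropout=0.0, bias=False)
    model = GPT(config)
    model.eval()
    prompt = torch.zeros((1, 1), dtype=torch.long)  # a single dummy start token
    out = model.generate(prompt, max_new_tokens=16, temperature=1.0, top_k=8)
    print(out.shape)  # expected: torch.Size([1, 17])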