HaileyStorm
committed on
Upload chess-mamba-vs-xformer/train_bygame.py with huggingface_hub
chess-mamba-vs-xformer/train_bygame.py
CHANGED
@@ -92,17 +92,22 @@ anneal_decay_iters = None # Set at init
 
 if model_type == 'mamba':
     from mamba_lm import MambaLM, MambaLMConfig
+    from mamba_ssm import MambaLMHeadModel
     model_config = MambaLMConfig(
         d_model=d_model,
-        n_layers=n_layer,
-        … (eight further removed argument lines were not preserved in the extracted page)
+        #n_layers=n_layer,
+        n_layer=n_layer,
+        ssm_cfg={
+            'dt_rank': dt_rank,
+            'd_state': d_state,
+            #'expand_factor': expand_factor,
+            'bias': bias,
+            'conv_bias':conv_bias,
+            #'pscan':pscan,
+        },
+        vocab_size=vocab_size,
+        pad_vocab_size_multiple=1
+    ).to_mamba_config()
 elif model_type == 'xformer':
     from xformer import GPTConfig, GPT
     model_config = GPTConfig(
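This hunk swaps the model backend: the local mamba_lm MambaLMConfig is kept only to assemble the configuration, block-level options (dt_rank, d_state, bias, conv_bias) move into the nested ssm_cfg dict, n_layers becomes n_layer, and the result is converted with .to_mamba_config() so it can drive mamba_ssm's MambaLMHeadModel. As a rough illustration only, here is a minimal sketch of building an equivalent model directly against mamba_ssm; it assumes mamba_ssm and its CUDA kernels are installed, and every hyperparameter value is made up for the example rather than taken from this repo.

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Illustrative hyperparameters only -- not the values used by train_bygame.py.
config = MambaConfig(
    d_model=512,
    n_layer=8,
    vocab_size=32,
    ssm_cfg={"d_state": 16, "dt_rank": "auto", "bias": False, "conv_bias": True},
    pad_vocab_size_multiple=1,
)
model = MambaLMHeadModel(config, device="cuda")

input_ids = torch.randint(0, config.vocab_size, (2, 64), device="cuda")
logits = model(input_ids).logits  # output object; .logits is (batch, seq_len, vocab_size)
```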
@@ -152,10 +157,13 @@ train_files = glob.glob(os.path.join(data_dir, 'train*.parquet')) + \
               glob.glob(os.path.join(data_dir, 'stable*.parquet')) + \
               glob.glob(os.path.join(data_dir, 'anneal*.parquet'))
 train_datasets = []
+print("Loading dataset...")
 for f in train_files:
     dataset = pq.read_table(f).to_pandas()
     dataset = dataset[dataset['tokenized'].apply(len) >= 8]
     train_datasets.append(dataset)
+    print('.',end='',flush=True)
+print("\nLoaded.")
 #val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
 #val_data = val_data[val_data['tokenized'].apply(len) >= 8]
 truncated_games_count = 0
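Dataset loading itself is unchanged (glob the train*/stable*/anneal* parquet shards, drop games tokenized to fewer than 8 tokens, keep one DataFrame per shard); the commit only adds a simple progress printout. A self-contained sketch of that loading pattern, with a hypothetical helper name, for reference:

```python
import glob
import os
import pyarrow.parquet as pq

def load_parquet_shards(data_dir, min_len=8):
    """Hypothetical helper mirroring the script's loader: read every
    train*/stable*/anneal* shard and drop games shorter than min_len tokens."""
    files = (glob.glob(os.path.join(data_dir, 'train*.parquet')) +
             glob.glob(os.path.join(data_dir, 'stable*.parquet')) +
             glob.glob(os.path.join(data_dir, 'anneal*.parquet')))
    shards = []
    print("Loading dataset...")
    for f in files:
        df = pq.read_table(f).to_pandas()
        df = df[df['tokenized'].apply(len) >= min_len]
        shards.append(df)
        print('.', end='', flush=True)
    print("\nLoaded.")
    return shards
```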
@@ -217,7 +225,8 @@ if init_from == 'scratch':
     else:
         model_config.vocab_size = meta_vocab_size
     if model_type == 'mamba':
-        model = MambaLM(model_config)
+        #model = MambaLM(model_config)
+        model = MambaLMHeadModel(model_config)
     else:
         model = GPT(model_config)
     if auto_clip:
@@ -233,7 +242,8 @@ elif init_from == 'resume' or init_from == 'anneal':
     checkpoint = torch.load(ckpt_path, map_location=device)
     model_config = checkpoint['model_args']
     if model_type == 'mamba':
-        model = MambaLM(model_config)
+        #model = MambaLM(model_config)
+        model = MambaLMHeadModel(model_config)
     else:
         model = GPT(model_config)
     state_dict = checkpoint['model']
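Both the scratch and the resume/anneal paths now build the Mamba model with MambaLMHeadModel, and on resume the configuration and weights come from the two checkpoint keys visible here ('model_args' and 'model'). The sketch below is a minimal, hypothetical save/resume round trip assuming only that layout; anything else the real checkpoint stores (optimizer state, iteration counters, and so on) is ignored.

```python
import torch

def save_checkpoint(model, model_config, path):
    # Persist only the two keys the resume path above reads.
    torch.save({'model': model.state_dict(), 'model_args': model_config}, path)

def load_checkpoint(path, build_model, device='cuda'):
    # build_model is a callable such as MambaLMHeadModel or GPT that accepts the saved config.
    checkpoint = torch.load(path, map_location=device)
    model = build_model(checkpoint['model_args'])
    model.load_state_dict(checkpoint['model'])
    return model.to(device)
```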
@@ -309,10 +319,11 @@ if ddp:
 
 def batch_to_loss(sequences, max_length_in_batch):
     if model_type == 'mamba':
-        logits = model(sequences[:, :-1]) # Forward pass, exclude last token for input
+        logits = model(sequences[:, :-1]).logits # Forward pass, exclude last token for input
         # Compute loss (assuming next token prediction task)
         targets = sequences[:, 1:].reshape(-1) # Shifted by one for next token prediction
         return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
+        #return F.cross_entropy(logits.reshape(-1), targets)
     else:
         inputs = sequences[:, :-1]
         targets = sequences[:, 1:].reshape(-1)
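The only functional change here is the trailing .logits: MambaLMHeadModel returns an output object rather than the raw tensor MambaLM produced, so the logits tensor has to be pulled out of it before computing the loss. The loss itself is ordinary shifted next-token cross-entropy; a self-contained sketch (random stand-in logits instead of a real model) follows.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits, sequences):
    """Shifted next-token cross-entropy, as in batch_to_loss.
    logits: (B, T-1, V) from model(sequences[:, :-1]); sequences: (B, T) token ids."""
    targets = sequences[:, 1:].reshape(-1)  # shift targets one step so position t predicts token t+1
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)

# Toy usage with random data, just to show the shapes involved.
B, T, V = 4, 16, 32
sequences = torch.randint(0, V, (B, T))
logits = torch.randn(B, T - 1, V)
loss = next_token_loss(logits, sequences)
```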
@@ -474,7 +485,7 @@ while True:
     scaler.update()
     # flush the gradients as soon as we can, no need for this memory anymore
     optimizer.zero_grad(set_to_none=True)
-    torch.cuda.empty_cache()
+    #torch.cuda.empty_cache()
 
     # timing and logging
     t1 = time.time()
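The last hunk simply comments out the per-step torch.cuda.empty_cache() call: once gradients are released with zero_grad(set_to_none=True), emptying the CUDA caching allocator every iteration mostly forces the next step to re-allocate memory and can slow training. For context, a hedged skeleton of a GradScaler-based AMP step around these lines (placeholder names, not the repo's actual loop):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(optimizer, compute_loss):
    # compute_loss() stands in for batch_to_loss(sequences, max_length_in_batch).
    with torch.cuda.amp.autocast():
        loss = compute_loss()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # Free gradient memory right away; no per-step torch.cuda.empty_cache() needed.
    optimizer.zero_grad(set_to_none=True)
    return loss.item()
```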