MasaakiKotera committed
Commit b612eb3
Parent(s): 1b3f51e

Upload train.py with huggingface_hub

train.py CHANGED
@@ -41,9 +41,6 @@ log_and_write(log_dir, f'training data: {data_dir}')
 # -----------------------------------------------------------------------------
 
 
-# various inits, derived attributes, I/O setup
-# ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
-
 ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
 if ddp:
     init_process_group(backend=backend)
@@ -53,13 +50,10 @@ if ddp:
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-    seed_offset = ddp_rank
-    # world_size number of processes will be training simultaneously, so we can scale
-    # down the desired gradient accumulation iterations per process proportionally
+    seed_offset = ddp_rank
     assert gradient_accumulation_steps % ddp_world_size == 0
     gradient_accumulation_steps //= ddp_world_size
 else:
-    # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
     ddp_world_size = 1
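The hunk above keeps the nanoGPT-style rule of splitting the desired gradient accumulation steps evenly across DDP ranks, so the effective global batch does not depend on the number of GPUs. A minimal sketch of that arithmetic, with hypothetical numbers (the variable names mirror the script's config, the values are made up):

# Illustrative only: per-rank accumulation keeps the global batch fixed.
global_grad_accum = 40                  # desired accumulation steps across all ranks
ddp_world_size = 8                      # number of DDP processes
assert global_grad_accum % ddp_world_size == 0
per_rank_accum = global_grad_accum // ddp_world_size    # 5 micro-steps per rank

batch_size, block_size = 12, 1024       # assumed config values
tokens_per_step = per_rank_accum * ddp_world_size * batch_size * block_size
print(tokens_per_step)                  # 491520 tokens per optimizer step, same as single-GPU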
@@ -85,7 +79,6 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # data loader
-# data_dir = os.path.join('data', dataset)
 train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
 val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
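For context on the two memmapped .bin files: they hold the tokenized corpus as a flat uint16 array, and get_batch slices random windows out of them. The body of get_batch is not part of this diff, so the following is only a sketch of the usual pattern, assuming block_size and batch_size come from the script's config:

# Sketch only, not the script's actual get_batch implementation.
import numpy as np
import torch

def sample_batch(data, block_size, batch_size):
    # pick random start offsets, then cut input/target windows shifted by one token
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x, y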
@@ -100,7 +93,6 @@ def get_batch(split):
     x, y = x.to(device), y.to(device)
     return x, y
 
-# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
 iter_num = 0
 best_val_loss = 1e9
 
@@ -127,7 +119,6 @@ if init_from == 'scratch':
 elif init_from == 'resume':
     print(f"Resuming training from {out_dir}")
     # resume training from a checkpoint.
-    # ckpt_path = os.path.join(out_dir, 'ckpt.pt')
     checkpoint = torch.load(ckpt_path, map_location=device)
     checkpoint_model_args = checkpoint['model_args']
     # force these config attributes to be equal otherwise we can't even resume training
@@ -138,8 +129,6 @@ elif init_from == 'resume':
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
     state_dict = checkpoint['model']
-    # fix the keys of the state dictionary :(
-    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
     unwanted_prefix = '_orig_mod.'
     for k,v in list(state_dict.items()):
         if k.startswith(unwanted_prefix):
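On the '_orig_mod.' prefix handled just above: when the model has been wrapped by torch.compile, its saved state_dict keys carry that prefix, and the loop (whose stripping line falls between this hunk and the next) renames them back. A toy illustration of the renaming, using a made-up dict rather than a real checkpoint:

# Toy dict, not a real checkpoint; shows the effect of the prefix-stripping loop.
state_dict = {'_orig_mod.lm_head.weight': 0, 'transformer.wte.weight': 1}
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
print(state_dict)  # {'transformer.wte.weight': 1, 'lm_head.weight': 0}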
@@ -147,14 +136,6 @@ elif init_from == 'resume':
     model.load_state_dict(state_dict)
     iter_num = checkpoint['iter_num']
     best_val_loss = checkpoint['best_val_loss']
-elif init_from.startswith('gpt2'):
-    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
-    # initialize from OpenAI GPT-2 weights
-    override_args = dict(dropout=dropout)
-    model = GPT.from_pretrained(init_from, override_args)
-    # read off the created config params, so we can store them into checkpoint correctly
-    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
-        model_args[k] = getattr(model.config, k)
 # crop down the model block size if desired, using model surgery
 if block_size < model.config.block_size:
     model.crop_block_size(block_size)
@@ -188,7 +169,7 @@ def estimate_loss():
     model.eval()
     for split in ['train', 'val']:
         losses = torch.zeros(eval_iters)
-        total_loss = 0
+        total_loss = 0
         for k in range(eval_iters):
             X, Y = get_batch(split)
             with ctx:
@@ -197,7 +178,7 @@ def estimate_loss():
             total_loss += loss.item()
         avg_loss = losses.mean()
         out[split] = avg_loss
-        perplexities[split] = torch.exp(avg_loss)
+        perplexities[split] = torch.exp(avg_loss)
     model.train()
     return out, perplexities
 
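The perplexity reported here is simply the exponential of the mean per-token cross-entropy over the eval_iters evaluation batches. A minimal, self-contained check of that relation (the loss values are invented for illustration):

# perplexity = exp(mean cross-entropy in nats); made-up numbers for illustration
import torch

losses = torch.tensor([2.31, 2.28, 2.35, 2.30])  # per-batch mean cross-entropy
avg_loss = losses.mean()
perplexity = torch.exp(avg_loss)
print(f"loss {avg_loss.item():.4f} -> perplexity {perplexity.item():.2f}")  # about 10.07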
@@ -235,19 +216,20 @@ while True:
         log_and_write(log_dir, f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f},train perplexity: {perplexities['train']:.4f}, val perplexity: {perplexities['val']:.4f}")
         if iter_num % 200 == 0:
             print_gpu_memory_usage()
-        if
-
-
-
-
-
-
-
-
-
-
-
-
+        if always_save_checkpoint:
+            if losses['val'] < best_val_loss or always_save_checkpoint:
+                best_val_loss = losses['val']
+                if iter_num > 0:
+                    checkpoint = {
+                        'model': raw_model.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'model_args': model_args,
+                        'iter_num': iter_num,
+                        'best_val_loss': best_val_loss,
+                        'config': config,
+                    }
+                    log_and_write(log_dir, f"saving checkpoint to {out_dir}")
+                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
     if iter_num == 0 and eval_only:
         break
 
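On the checkpoint block added at the end of the diff: each save now goes to its own file, ckpt_{iter_num}.pt, with the same keys the resume branch earlier in the file reads back. A hedged sketch of reloading one of those files (the directory and iteration number are hypothetical):

# Sketch: reading back one of the per-iteration checkpoints written by this commit.
import os
import torch

out_dir = 'out'                                    # assumed output directory
ckpt_path = os.path.join(out_dir, 'ckpt_2000.pt')  # hypothetical iter_num
checkpoint = torch.load(ckpt_path, map_location='cpu')

print(checkpoint['iter_num'], checkpoint['best_val_loss'])
# 'model_args' can rebuild the config and 'model' holds the unwrapped state_dict, e.g.
# gptconf = GPTConfig(**checkpoint['model_args']); model = GPT(gptconf)
# model.load_state_dict(checkpoint['model'])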