Spaces:
Runtime error
Runtime error
import logging | |
import os | |
import time | |
from argparse import Namespace | |
from pathlib import Path | |
import datasets | |
import torch | |
from accelerate import Accelerator, DistributedType | |
from accelerate.utils import ProjectConfiguration | |
from arguments import TrainingArguments | |
from datasets import load_dataset | |
from huggingface_hub import Repository | |
from torch.optim import AdamW | |
from torch.utils.data import IterableDataset | |
from torch.utils.data.dataloader import DataLoader | |
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe | |
import transformers | |
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed | |
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.

    Examples are read from ``dataset`` into a buffer until the buffer holds at least
    ``max_buffer_size`` characters (or tokens, when ``tokenized`` is true). The buffered
    examples are tokenized if needed, joined with ``tokenizer.bos_token_id`` appended after
    each example, and cut into chunks of exactly ``seq_length`` tokens. A trailing chunk
    shorter than ``seq_length`` is discarded.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
        tokenized (bool): If true we use a pretokenized dataset.

    Attributes set by the constructor and updated while iterating:
        epoch (int): Number of times the underlying dataset has been restarted.
        current_size (int): Number of sequences yielded so far.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        tokenized=False,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.epoch = 0
        self.infinite = infinite
        self.current_size = 0
        self.tokenized = tokenized
        if self.tokenized:
            # Buffer length is measured in tokens.
            self.max_buffer_size = seq_length * num_of_sequences
            self.content_field = "input_ids"
        else:
            # Buffer length is measured in characters, estimated from chars_per_token.
            self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
            self.content_field = "content"

    def __iter__(self):
        # Same logger object the script obtains via logging.getLogger(__name__); looking it
        # up here avoids depending on a module-level `logger` that is bound later in the file.
        log = logging.getLogger(__name__)
        iterator = iter(self.dataset)
        examples_in_pass = 0  # examples read since the iterator was last (re)started
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.max_buffer_size:
                try:
                    example = next(iterator)
                except StopIteration:
                    if self.infinite and examples_in_pass > 0:
                        iterator = iter(self.dataset)
                        examples_in_pass = 0
                        self.epoch += 1
                        log.info(f"Dataset epoch: {self.epoch}")
                        continue
                    if self.infinite:
                        # A full pass produced nothing: restarting again would spin forever.
                        log.warning("Dataset yielded no examples; stopping instead of restarting it forever.")
                    more_examples = False
                    break
                buffer.append(example[self.content_field])
                buffer_len += len(buffer[-1])
                examples_in_pass += 1

            if self.tokenized or not buffer:
                # Nothing to tokenize: either already token ids or an empty buffer.
                tokenized_inputs = buffer
            else:
                tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]

            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
                all_token_ids.append(self.concat_token_id)

            # Only emit full-length chunks; the incomplete tail of the buffer is dropped.
            for start in range(0, len(all_token_ids) - self.seq_length + 1, self.seq_length):
                self.current_size += 1
                yield torch.tensor(all_token_ids[start : start + self.seq_length])

    def shuffle(self, buffer_size=1000):
        """Return this dataset wrapped in a ``ShufflerIterDataPipe`` with the given buffer size."""
        return ShufflerIterDataPipe(self, buffer_size=buffer_size)
def setup_logging(args):
    """Configure logging for this process and start the experiment trackers.

    Every process logs to ``<save_dir>/log/debug_<process_index>.log`` and to the console.
    On the main process the trackers are initialised with the run's arguments and the
    libraries are set to INFO verbosity; all other processes only report errors.

    Args:
        args: Parsed arguments; ``model_ckpt`` and ``save_dir`` are read here and the whole
            namespace is passed to the trackers as the run configuration.

    Returns:
        tuple: ``(logger, run_name)`` where ``run_name`` is the run name of the first
        tracker on the main process and ``""`` elsewhere.
    """
    project_name = args.model_ckpt.split("/")[-1]
    logger = logging.getLogger(__name__)

    log_dir = Path(args.save_dir) / "log/"
    log_dir.mkdir(exist_ok=True)
    log_file = log_dir / f"debug_{accelerator.process_index}.log"
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )

    if not accelerator.is_main_process:
        # Keep secondary processes quiet.
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        return logger, run_name

    # Trackers are only set up once, on the main process.
    accelerator.init_trackers(project_name, vars(args))
    run_name = accelerator.trackers[0].run.name
    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_info()
    transformers.utils.logging.set_verbosity_info()
    return logger, run_name
def create_dataloaders(args):
    """Build the training and evaluation dataloaders from streamed datasets.

    Both datasets are loaded in streaming mode from their ``"train"`` split. The training
    stream is shuffled at the example level (seeded) and again at the sequence level after
    chunking; the validation stream is chunked without shuffling and is not repeated.

    Args:
        args: Parsed arguments providing dataset names, shuffle buffer size, seed, sequence
            length, the ``tokenized`` flag and both batch sizes.

    Returns:
        tuple: ``(train_dataloader, eval_dataloader)``.
    """
    stream_opts = {"streaming": True}
    raw_train = load_dataset(args.dataset_name_train, split="train", **stream_opts)
    raw_train = raw_train.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    raw_valid = load_dataset(args.dataset_name_valid, split="train", **stream_opts)

    # Uses the module-level `tokenizer`.
    train_chunks = ConstantLengthDataset(
        tokenizer, raw_train, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
    )
    valid_chunks = ConstantLengthDataset(
        tokenizer, raw_valid, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
    )

    shuffled_train = train_chunks.shuffle(buffer_size=args.shuffle_buffer)
    return (
        DataLoader(shuffled_train, batch_size=args.train_batch_size, shuffle=True),
        DataLoader(valid_chunks, batch_size=args.valid_batch_size),
    )
def get_grouped_params(model, args, no_decay=("bias", "ln_1.weight", "ln_2.weight", "ln_f.weight")):
    """Split the model parameters into a weight-decayed and a non-decayed optimizer group.

    Args:
        model: Object exposing ``named_parameters()`` yielding ``(name, parameter)`` pairs.
        args: Object with a ``weight_decay`` attribute used for the decayed group.
        no_decay: Name fragments; a parameter whose name contains any of them gets no
            weight decay. The default is an immutable tuple so it cannot be altered
            between calls.

    Returns:
        list[dict]: Two parameter groups, decayed parameters first, then the non-decayed
        ones with ``weight_decay`` set to ``0.0``.
    """
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        bucket = undecayed if any(fragment in name for fragment in no_decay) else decayed
        bucket.append(param)
    return [
        {"params": decayed, "weight_decay": args.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
def log_metrics(step, metrics):
    """Write ``metrics`` to the log and, on the main process, to the accelerator trackers.

    Args:
        step: Step the metrics belong to.
        metrics: Mapping of metric names to values.
    """
    message = f"Step {step}: {metrics}"
    logger.info(message)
    if not accelerator.is_main_process:
        return
    accelerator.log(metrics, step)
def compute_tflops(elapsed_time, accelerator, args):
    """Estimate the achieved TFLOPs for one optimizer iteration.

    Uses the TFLOPs formula from Equation 3 in Section 5.1 of
    https://arxiv.org/pdf/2104.04473.pdf. The model configuration is read from the
    module-level ``model`` and the vocabulary size from the module-level ``tokenizer``.

    Args:
        elapsed_time: Wall-clock seconds the iteration took.
        accelerator: Accelerator used to unwrap the model and to get the process count.
        args: Parsed arguments (``gradient_checkpointing``, ``train_batch_size``,
            ``gradient_accumulation_steps`` and ``seq_length`` are read).

    Returns:
        float: FLOPs of the iteration divided by elapsed time, process count and 10**12.
    """
    cfg = accelerator.unwrap_model(model).config
    n_layer, n_embd = cfg.n_layer, cfg.n_embd

    # Gradient checkpointing is accounted for with a larger multiplier.
    if args.gradient_checkpointing:
        checkpoint_factor = 4
    else:
        checkpoint_factor = 3

    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
    factor = 24 * checkpoint_factor * batch_size * args.seq_length * n_layer * (n_embd**2)

    attention_term = args.seq_length / (6.0 * n_embd)
    vocab_term = tokenizer.vocab_size / (16.0 * n_layer * n_embd)
    flops_per_iteration = factor * (1.0 + attention_term + vocab_term)

    return flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
def evaluate(args):
    """Evaluate the module-level ``model`` on ``eval_dataloader``.

    The model is switched to eval mode and left there; the caller is responsible for
    calling ``model.train()`` again.

    Args:
        args: Parsed arguments. ``valid_batch_size`` is used to repeat each batch loss and
            ``max_eval_steps`` caps the number of evaluated batches when it is positive
            (a non-positive value evaluates the whole dataloader).

    Returns:
        tuple[float, float]: ``(loss, perplexity)``. The perplexity is ``exp(loss)`` and is
        ``inf`` when the exponential overflows.
    """
    model.eval()
    losses = []
    # Count batches from 1 so that exactly `max_eval_steps` batches are evaluated.
    for step, batch in enumerate(eval_dataloader, start=1):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        # Repeat the batch loss once per sample before gathering it across processes
        # (assumes `outputs.loss` is a scalar mean over the batch — TODO confirm).
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    losses = torch.cat(losses)
    # Only keep as many entries as sequences the dataset has actually produced.
    loss = losses[: eval_dataloader.dataset.current_size].mean()
    # torch.exp saturates to inf on overflow instead of raising, so no try/except is needed.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()
# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()

# Accelerator
config = ProjectConfiguration(project_dir=args.save_dir, logging_dir="log")
accelerator = Accelerator(log_with=["wandb", "tensorboard"], project_config=config)
# Merge the (stringified) accelerator state into the arguments so it is tracked with the run.
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Clone model repository
if accelerator.is_main_process:
    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
# Every process writes its log file into, and later loads the model from, `args.save_dir`.
# Hold them all here until the main process has finished cloning, otherwise the others can
# touch the directory before (or while) the clone happens.
accelerator.wait_for_everyone()

# Logging
logger, run_name = setup_logging(args)
logger.info(accelerator.state)

# Checkout new branch on repo
if accelerator.is_main_process:
    hf_repo.git_checkout(run_name, create_branch_ok=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
# The scheduler is not passed through `accelerator.prepare`, so register it explicitly to
# have its state included in saved/loaded checkpoints.
accelerator.register_for_checkpointing(lr_scheduler)
def get_lr():
    """Return the current learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group["lr"]
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Load in the weights and states from a previous save.
# `resume_from_checkpoint` is the path of a checkpoint directory named `step_<N>`. The former
# "pick the most recent checkpoint" fallback could never run (its guard was always true once
# the option was set), so the option is simply treated as the path to resume from.
if args.resume_from_checkpoint:
    accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
    accelerator.load_state(args.resume_from_checkpoint)
    # normpath drops a trailing separator, which would otherwise make basename return "".
    path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
    # Extract the step of the checkpoint to continue from there
    training_difference = os.path.splitext(path)[0]
    resume_step = int(training_difference.replace("step_", ""))
# Train model
model.train()
# A checkpoint `step_<N>` is written right after the optimizer update of batch N, so a resumed
# run has already completed N // gradient_accumulation_steps updates. Starting from that count
# keeps `max_train_steps` (and the restored LR schedule, which is sized to it) consistent.
if args.resume_from_checkpoint:
    completed_steps = resume_step // args.gradient_accumulation_steps
else:
    completed_steps = 0
t_start = time.time()
loss_tracking = 0
for step, batch in enumerate(train_dataloader, start=1):
    if args.resume_from_checkpoint and step <= resume_step:
        # Skip every batch up to and including the checkpointed one: `step` is 1-based and
        # batch `resume_step` was already trained on before the checkpoint was saved.
        continue
    loss = model(batch, labels=batch, use_cache=False).loss
    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
    loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
    log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
    loss = loss / args.gradient_accumulation_steps
    if step % args.gradient_accumulation_steps != 0:
        # Prevent backward from doing gradient all_reduce in every step
        if accelerator.distributed_type == DistributedType.MULTI_GPU:
            with model.no_sync():
                accelerator.backward(loss)
        else:
            accelerator.backward(loss)
    else:
        # Last micro-batch of the accumulation window: take the optimizer step.
        lr = get_lr()
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        elapsed_time = time.time() - t_start
        tflops = compute_tflops(elapsed_time, accelerator, args)
        log_metrics(
            step,
            {
                "steps": completed_steps,
                "loss/train": loss_tracking,
                "lr": lr,
                "tflops": tflops,
                "time_per_iteration": elapsed_time,
            },
        )
        t_start = time.time()
        loss_tracking = 0
        completed_steps += 1
        if step % args.save_checkpoint_steps == 0:
            logger.info("Evaluating and saving model checkpoint")
            eval_loss, perplexity = evaluate(args)
            log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
            accelerator.wait_for_everyone()
            save_dir = os.path.join(args.save_dir, f"step_{step}")
            accelerator.save_state(save_dir)
            if accelerator.is_main_process:
                hf_repo.push_to_hub(commit_message=f"step {step}")
            # `evaluate` leaves the model in eval mode.
            model.train()
        if completed_steps >= args.max_train_steps:
            break
# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})

# Make sure every process is done before anything is written.
accelerator.wait_for_everyone()

# Save the bare model weights into the repository directory ...
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)

# ... and the full accelerator state into a `step_<N>` sub-directory.
save_dir = os.path.join(args.save_dir, f"step_{step}")
accelerator.save_state(save_dir)

if accelerator.is_main_process:
    hf_repo.push_to_hub(commit_message="final model")