import logging
import sys
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
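
# Log to stdout so messages show up in the console and in captured job output.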
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


@dataclass
class ScriptArguments:
    """
    Arguments that aren't included in TrainingArguments.
    """

    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Path of a checkpoint folder to resume training from."}
    )
    dataset_id: Optional[str] = field(
        default=None, metadata={"help": "The repository id of the dataset to use (via the datasets library)."}
    )
    tokenizer_id: Optional[str] = field(
        default=None, metadata={"help": "The repository id of the tokenizer to use (via AutoTokenizer)."}
    )
    repository_id: Optional[str] = field(
        default=None,
        metadata={"help": "The repository id where the model will be saved or loaded from for further pre-training."},
    )
    model_config_id: Optional[str] = field(
        default="bert-base-uncased", metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    per_device_train_batch_size: Optional[int] = field(
        default=16,
        metadata={"help": "The batch size per HPU used during training."},
    )
    max_steps: Optional[int] = field(
        default=1_000_000,
        metadata={"help": "The number of training steps to perform."},
    )
    learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "Learning rate for the training."})
    mlm_probability: Optional[float] = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for the masked language modeling loss."}
    )


def run_mlm():
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
    logger.info(f"Script parameters {script_args}")
    seed = 34
    set_seed(seed)
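
    # Load the training split. The script assumes the dataset is already tokenized,
    # since no tokenization step runs before the batches reach the data collator.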
    train_dataset = load_dataset(script_args.dataset_id, split="train")

    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_id)
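
    # Build the model from a config only, i.e. with randomly initialized weights for
    # pre-training from scratch (from_pretrained() would load existing weights instead).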
    logger.info("Training new model from scratch")
    config = AutoConfig.from_pretrained(script_args.model_config_id)
    model = AutoModelForMaskedLM.from_config(config)
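
    # Make sure the embedding matrix matches the tokenizer's vocabulary size.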
    logger.info(f"Resizing token embedding to {len(tokenizer)}")
    model.resize_token_embeddings(len(tokenizer))
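
    # The collator masks tokens on the fly when each batch is built (dynamic masking);
    # pad_to_multiple_of=8 keeps tensor shapes efficient on accelerator hardware.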
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=script_args.mlm_probability, pad_to_multiple_of=8
    )
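
    # Training runs for a fixed number of steps rather than epochs; checkpoints are
    # written every 5,000 steps and only the two most recent ones are kept on disk.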
    training_args = TrainingArguments(
        output_dir=script_args.repository_id,
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        learning_rate=script_args.learning_rate,
        seed=seed,
        max_steps=script_args.max_steps,
        # logging & checkpointing
        logging_dir=f"{script_args.repository_id}/logs",
        logging_strategy="steps",
        logging_steps=100,
        save_strategy="steps",
        save_steps=5_000,
        save_total_limit=2,
        report_to="tensorboard",
        # distributed training
        ddp_find_unused_parameters=True,
    )
    trainer = Trainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
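
    # Start training; an interrupted run can be resumed by passing a checkpoint
    # folder via --resume_from_checkpoint (None starts a fresh run).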
    trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)


if __name__ == "__main__":
    run_mlm()