|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a |
|
text file or a dataset. |
|
|
|
Here is the full list of checkpoints on the hub that can be fine-tuned by this script: |
|
https://huggingface.co/models?filter=masked-lm |
|
""" |
|
import logging |
|
import os |
|
import sys |
|
import time |
|
from dataclasses import dataclass, field |
|
|
|
|
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import numpy as np |
|
from datasets import load_dataset |
|
from tqdm import tqdm |
|
|
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
import optax |
|
from flax import jax_utils, traverse_util |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard |
|
from transformers import ( |
|
CONFIG_MAPPING, |
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING, |
|
AutoConfig, |
|
AutoTokenizer, |
|
FlaxAutoModelForMaskedLM, |
|
HfArgumentParser, |
|
PreTrainedTokenizerBase, |
|
TensorType, |
|
TrainingArguments, |
|
is_tensorboard_available, |
|
set_seed, |
|
) |
|
|
|
|
|
|
|
# Decide once, at import time, whether metrics can be mirrored to TensorBoard.
# `has_tensorboard` ends up True only if the availability check passes AND
# Flax's SummaryWriter can actually be imported.
has_tensorboard = is_tensorboard_available()
if not has_tensorboard:
    print(
        "Unable to display metrics through TensorBoard because the package is not installed: "
        "Please run pip install tensorboard to enable."
    )
else:
    try:
        from flax.metrics.tensorboard import SummaryWriter
    except ImportError as import_error:
        # The availability check passed, but importing the Flax writer still failed.
        has_tensorboard = False
        print(f"Unable to display metrics through TensorBoard because some package are not installed: {import_error}")
|
|
|
|
|
# Config classes registered in the Flax masked-LM mapping, and the `model_type`
# string of each one (listed in the `--model_type` help text).
MODEL_CONFIG_CLASSES = [*FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Every field is exposed as a command-line option through ``HfArgumentParser``; the ``help`` entry of each
    field's metadata is the text shown to the user.
    """

    # Checkpoint used to initialize the weights; leave unset to train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The trailing space keeps the two concatenated sentences from running together.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Architecture to instantiate when no checkpoint/config is given.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Name of a `jax.numpy` dtype; the training script resolves it with `getattr(jnp, dtype)`.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises:
        ValueError: at construction time (``__post_init__``) if neither a dataset name nor a data file is
            given, or if a data file does not end in ``csv``, ``json`` or ``txt``.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated. Default to the max input length of the model."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    def __post_init__(self):
        """Check that a data source was given and that data files use a supported extension."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Raise explicitly instead of using `assert`: assertions are removed under `python -O`,
        # which would silently disable this user-input validation.
        supported_extensions = ("csv", "json", "txt")
        for arg_name in ("train_file", "validation_file"):
            data_file = getattr(self, arg_name)
            if data_file is not None and data_file.split(".")[-1] not in supported_extensions:
                raise ValueError(f"`{arg_name}` should be a csv, a json or a txt file.")
|
|
|
|
|
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`. When the key is absent, the mask is recomputed from the
        token ids with the tokenizer.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15

    def __post_init__(self):
        """Reject tokenizers without a mask token, since masking would be impossible."""
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
        """
        Pad ``examples`` into one NumPy batch and apply random masking.

        Args:
            examples: tokenized samples, as accepted by ``tokenizer.pad``.
            pad_to_multiple_of: forwarded to ``tokenizer.pad``.

        Returns:
            The padded batch, with ``"input_ids"`` replaced by their masked version, a new ``"labels"`` entry,
            and the ``"special_tokens_mask"`` entry removed.
        """
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # The mask is only needed to choose maskable positions; it is not a model input.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: array of token ids. It is modified in place and also returned.
            special_tokens_mask: array with the same shape as ``inputs``, truthy where a token must never be
                masked. If ``None``, it is derived from ``inputs`` via ``tokenizer.get_special_tokens_mask``.

        Returns:
            ``(inputs, labels)``, where ``labels`` holds the original ids at the selected positions and -100
            everywhere else.
        """
        labels = inputs.copy()

        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # No precomputed mask in the batch: rebuild it row by row from the token ids.
            special_tokens_mask = np.array(
                [
                    self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
                    for row in labels.tolist()
                ],
                dtype=bool,
            )
        else:
            special_tokens_mask = special_tokens_mask.astype("bool")

        # Special tokens get probability zero, so they are never selected.
        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100

        # 80% of the selected positions are replaced by the mask token.
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Half of the remaining selected positions (10% overall) get a random token id.
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced

        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the selected positions (10% overall) keep their original token.
        return inputs, labels
|
|
|
|
|
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> List[np.ndarray]:
    """
    Split sample indices into consecutive batches of exactly ``batch_size`` elements.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: one-dimensional array of sample indices.
        batch_size: number of indices per batch; must be positive.

    Returns:
        A list of arrays, each of length ``batch_size``. The list is empty when there are fewer than
        ``batch_size`` indices.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

    num_batches = len(samples_idx) // batch_size
    if num_batches == 0:
        # `np.split` cannot split into zero sections; there is simply no full batch.
        return []

    # Keep only the prefix that divides evenly into full batches.
    return np.split(samples_idx[: num_batches * batch_size], num_batches)
|
|
|
|
|
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """
    Write the training time, the training metrics and the evaluation metrics to ``summary_writer``.

    Args:
        summary_writer: object exposing ``scalar(tag, value, step)``.
        train_metrics: collected training metrics, passed through ``get_metrics`` before logging.
        eval_metrics: mapping from metric name to value, logged under ``eval_<name>`` at ``step``.
        train_time: value logged under ``train_time`` at ``step``.
        step: step of the last training value; earlier values are logged at the preceding steps.
    """
    summary_writer.scalar("train_time", train_time, step)

    collected_train_metrics = get_metrics(train_metrics)
    for name, values in collected_train_metrics.items():
        # The last value of the series lands on `step`, the one before on `step - 1`, and so on.
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(f"train_{name}", value, first_step + offset)

    for name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", value, step)
|
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Arguments: read the three argument groups from a single JSON file
    # or from the command line.
    # ------------------------------------------------------------------
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to write into a non-empty output directory unless explicitly allowed.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level="NOTSET",
        datefmt="[%X]",
    )
    logger = logging.getLogger(__name__)
    logger.info(f"Training/evaluation parameters {training_args}")

    set_seed(training_args.seed)

    # ------------------------------------------------------------------
    # Data: load from the hub by name, or from local csv/json/txt files.
    # ------------------------------------------------------------------
    if data_args.dataset_name is not None:
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            # No validation split: carve the first N% of the train split out as validation.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file

        # DataTrainingArguments guarantees at least one file is set when no dataset name is given;
        # fall back to the validation file so that an eval-only run does not crash on `None.split`.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # ------------------------------------------------------------------
    # Config and tokenizer.
    # ------------------------------------------------------------------
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # ------------------------------------------------------------------
    # Preprocessing: tokenize every text of the dataset.
    # ------------------------------------------------------------------
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length is None:
        # `--max_seq_length` defaults to None; `min(None, ...)` would raise a TypeError.
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            # Without this cap, a very large `model_max_length` would make `group_texts`
            # below produce no chunk at all.
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 1024
    else:
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # Each non-empty line is tokenized as its own sequence.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # `examples` is the list of lines of the text column (see `input_columns` below);
            # drop empty and whitespace-only lines.
            examples = [line for line in examples if len(line) > 0 and not line.isspace()]
            return tokenizer(
                examples,
                return_special_tokens_mask=True,
                padding=padding,
                truncation=True,
                max_length=max_seq_length,
            )

        tokenized_datasets = datasets.map(
            tokenize_function,
            input_columns=[text_column_name],
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    else:
        # Tokenize everything without truncation, then concatenate and cut into
        # chunks of exactly `max_seq_length` tokens.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

        def group_texts(examples):
            # Concatenate every column of the batch into one long list.
            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])

            # Drop the remainder that does not fill a whole chunk.
            total_length = (total_length // max_seq_length) * max_seq_length

            result = {
                k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
                for k, t in concatenated_examples.items()
            }
            return result

        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Only the first process writes TensorBoard summaries.
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))

    # The collator pads each batch and applies the random masking.
    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # ------------------------------------------------------------------
    # Model, schedule and optimizer.
    # ------------------------------------------------------------------
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    model_dtype = getattr(jnp, model_args.dtype)
    if model_args.model_name_or_path:
        # `--model_name_or_path` is documented as the checkpoint for weights initialization,
        # so load its weights instead of always starting from random ones.
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=model_dtype,
            cache_dir=model_args.cache_dir,
        )
    else:
        model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=model_dtype)

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    steps_per_epoch = len(tokenized_datasets["train"]) // train_batch_size
    num_train_steps = steps_per_epoch * num_epochs

    # Linear warmup from 0 to the peak learning rate, then linear decay back to 0.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    def decay_mask_fn(params):
        """Return a pytree of booleans: False for biases and `LayerNorm/scale`, True elsewhere."""
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=1e-8,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

    # ------------------------------------------------------------------
    # Per-device train / eval steps (run under `jax.pmap` with axis name "batch").
    # ------------------------------------------------------------------
    def train_step(state, batch, dropout_rng):
        """Run one optimization step; return the new state, the averaged metrics and a fresh dropout key."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # Only positions with a positive label contribute to the loss
            # (the collator sets every non-selected position to -100).
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

            # Mean over the contributing positions.
            loss = loss.sum() / label_mask.sum()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    def eval_step(params, batch):
        """Return summed loss, summed correct predictions and the count of scored positions, psum-ed over devices."""
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # Same position selection as in `train_step`.
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # Sums, not means: they are normalized after aggregating over all eval batches.
        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device.
    state = jax_utils.replicate(state)

    # ------------------------------------------------------------------
    # Training loop: one training pass and one evaluation pass per epoch.
    # ------------------------------------------------------------------
    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Fresh key for shuffling the training set this epoch.
        rng, input_rng = jax.random.split(rng)

        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Split the batch across the local devices and run one step.
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # Sum over all eval batches, then divide by the number of scored positions.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
        eval_normalizer = eval_metrics.pop("normalizer")
        eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)

        epochs.desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
        )

        # Save metrics.
        if has_tensorboard and jax.process_index() == 0:
            # `write_metric` places the last training value of the epoch on `cur_step`, so this must be
            # the number of steps completed so far; `epoch * steps_per_epoch` would put the first
            # epoch's values on non-positive steps.
            cur_step = (epoch + 1) * steps_per_epoch
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # Save a checkpoint after each epoch (first process only), taking the first replica of the params.
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )
|
|