# In [None]:
import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple
from copy import deepcopy
from multiprocessing import Pool

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModel
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    CamembertConfig,
    CamembertForMaskedLM,
    CamembertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup
)


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


# Module-level logger shared by the dataset, training and evaluation helpers below.
logger = logging.getLogger(__name__)

# NOTE: despite the class-like name this is a tokenizer *instance*, built when
# the module is imported (from_pretrained may need network access or a local
# cache to resolve "zhihan1996/DNA_bert_6").
DNATokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)


# Maps a model-type key to a (config class, LM model class, tokenizer) triple.
# The "dna" entry carries the DNATokenizer instance created above in its third
# slot, whereas every other entry carries a tokenizer class.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
}

# Relative position offsets that mask_tokens() adds around every sampled mask
# centre so that neighbouring tokens are masked together.
# Keys are presumably the k-mer size as a string (TODO confirm); mask_tokens()
# currently always uses the "6" entry.
MASK_LIST = {
    "3": [-1, 1],
    "4": [-1, 1, 2],
    "5": [-2, -1, 1, 2],
    "6": [-2, -1, 1, 2, 3]
}


class TextDataset(Dataset):
    """Dataset of fixed-size token blocks cut from a single text file.

    The whole file is tokenized as one stream and split into consecutive
    chunks of ``block_size`` token ids; a trailing chunk shorter than
    ``block_size`` is dropped. Each chunk is passed through
    ``tokenizer.build_inputs_with_special_tokens``. The resulting examples are
    pickled next to the source file and reloaded on later runs unless
    ``config['overwrite_cache']`` is truthy.

    Args:
        tokenizer: Tokenizer used to tokenize the text and add special tokens.
        config: Mapping; only ``config['overwrite_cache']`` is read here.
        file_path: Path of the text file to load. Must exist.
        block_size: Number of token ids per chunk before special tokens.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, config, file_path: str, block_size=512):
        assert os.path.isfile(file_path)

        directory, filename = os.path.split(file_path)
        # Fix: the prefix used to be the undefined name `dna`, which raised
        # NameError; it is the same "dna" literal LineByLineTextDataset uses.
        cached_features_file = os.path.join(
            directory, "dna" + "_cached_lm_" + str(block_size) + "_" + filename
        )

        if os.path.exists(cached_features_file) and not config['overwrite_cache']:
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE: pickle executes code on load; only point this at cache
            # files produced by this script.
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", directory)

            with open(file_path, encoding="utf-8") as f:
                text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

            # Step through the stream in whole blocks; the range bound drops
            # any trailing remainder shorter than block_size.
            self.examples = [
                tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
                for i in range(0, len(tokenized_text) - block_size + 1, block_size)
            ]
            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of cached examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return example ``item`` as a ``torch.long`` tensor of token ids."""
        return torch.tensor(self.examples[item], dtype=torch.long)

def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):
    """Encode ``lines`` in one batch call and return only their ``input_ids``.

    Kept as a module-level function so it can be handed to worker processes.
    """
    encoded = tokenizer.batch_encode_plus(
        lines,
        max_length=max_length,
        add_special_tokens=add_special_tokens,
    )
    return encoded["input_ids"]

class LineByLineTextDataset(Dataset):
    """Dataset with one encoded example per non-blank line of a text file.

    Lines are encoded with ``tokenizer.batch_encode_plus`` (special tokens
    added, ``max_length=block_size``), either in this process or split across
    ``config['n_process']`` worker processes. The encoded examples are pickled
    under ``/kaggle/working/`` and reloaded on later runs unless
    ``config['overwrite_cache']`` is truthy.

    Args:
        tokenizer: Tokenizer used to encode the lines.
        config: Mapping; reads ``config['overwrite_cache']`` and
            ``config['n_process']``.
        file_path: Path of the text file to load. Must exist.
        block_size: Value forwarded as ``max_length`` to the tokenizer.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, config, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        # Features ARE cached, in a fixed writable directory rather than next
        # to the source file (the earlier comment claiming otherwise was stale).
        _, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            '/kaggle/working/', 'dna' + "_cached_lm_" + str(block_size) + "_" + filename
        )

        if os.path.exists(cached_features_file) and not config['overwrite_cache']:
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE: pickle executes code on load; only point this at cache
            # files produced by this script.
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", file_path)

            with open(file_path, encoding="utf-8") as f:
                lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

            if config['n_process'] == 1:
                self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
            else:
                n_proc = config['n_process']
                # The context manager releases the workers even when
                # submitting or collecting a chunk raises.
                with Pool(n_proc) as pool:
                    # n_proc + 1 slice boundaries; the last slice absorbs the
                    # remainder of the integer division.
                    len_slice = len(lines) // n_proc
                    indexes = [len_slice * i for i in range(n_proc)] + [len(lines)]

                    results = []
                    for i in range(n_proc):
                        results.append(
                            pool.apply_async(
                                convert_line_to_example,
                                [tokenizer, lines[indexes[i]:indexes[i + 1]], block_size],
                            )
                        )
                        print(str(i) + " start")

                    # Collect inside the `with`: leaving it terminates the
                    # pool, and get() blocks until each chunk is finished.
                    # Results are gathered in submission order, so examples
                    # keep the order of the input lines.
                    self.examples = []
                    for result in results:
                        self.examples.extend(result.get())

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of encoded lines."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return example ``i`` as a ``torch.long`` tensor of token ids."""
        return torch.tensor(self.examples[i], dtype=torch.long)


def load_and_cache_examples(config, tokenizer, evaluate=False):
    """Build the dataset selected by ``config['line_by_line']``.

    ``evaluate`` is accepted for interface compatibility; training and
    evaluation currently read the same file, so it does not change the result.
    """
    data_file = r'/kaggle/input/random-dna-sequences-for-transfomer-pretraining/6_12k.txt'
    dataset_cls = LineByLineTextDataset if config['line_by_line'] else TextDataset
    return dataset_cls(tokenizer, config, file_path=data_file, block_size=config['block_size'])


def set_seed(config):
    """Seed the Python, NumPy and PyTorch RNGs from ``config['seed']``.

    CUDA generators are seeded as well when ``config['n_gpu']`` is positive.
    """
    seed = config['seed']
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if config['n_gpu'] > 0:
        torch.cuda.manual_seed_all(seed)


def _sorted_checkpoints(config, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    """Return checkpoint paths under the fixed output root, oldest first.

    Paths matching ``<prefix>-*`` are ordered by modification time when
    ``use_mtime`` is true, otherwise by the integer that follows the prefix
    (paths without such a number are skipped). ``config`` is not read.
    """
    checkpoint_root = r"/kaggle/working/output"
    candidates = glob.glob(os.path.join(checkpoint_root, "{}-*".format(checkpoint_prefix)))

    if use_mtime:
        keyed_paths = [(os.path.getmtime(path), path) for path in candidates]
    else:
        step_pattern = ".*{}-([0-9]+)".format(checkpoint_prefix)
        keyed_paths = []
        for path in candidates:
            found = re.match(step_pattern, path)
            if found and found.groups():
                keyed_paths.append((int(found.groups()[0]), path))

    # Tuples sort by key first and by path on ties.
    keyed_paths.sort()
    return [path for _, path in keyed_paths]


def _rotate_checkpoints(config, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints beyond ``config['save_total_limit']``.

    Does nothing when the limit is unset, zero or negative, or when the
    number of existing checkpoints does not exceed it.
    """
    limit = config['save_total_limit']
    if not limit or limit <= 0:
        return

    # Oldest checkpoints come first, so the surplus is a prefix of the list.
    ordered = _sorted_checkpoints(config, checkpoint_prefix, use_mtime)
    surplus = len(ordered) - limit
    if surplus <= 0:
        return

    for stale in ordered[:surplus]:
        logger.info("Deleting older checkpoint [{}] due to config['save_total_limit']".format(stale))
        shutil.rmtree(stale)




def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, config) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

    Mask centres are sampled per position with probability
    ``config['mlm_probability']`` (never on special or padding tokens), and
    each centre is then widened to its neighbours using the offsets in
    ``MASK_LIST['6']``.

    Args:
        inputs: Batch of token ids, iterated row by row (assumed 2-D,
            batch x sequence - TODO confirm). Modified IN PLACE.
        tokenizer: Tokenizer providing the mask/pad tokens and the
            special-token mask.
        config: Mapping; only ``config['mlm_probability']`` is read.

    Returns:
        ``(inputs, labels)``: ``inputs`` is the same tensor object that was
        passed in, with masked positions replaced; ``labels`` is a copy of the
        original ids with every non-masked position set to -100.

    Raises:
        ValueError: If the tokenizer has no mask token.
    """

    # Neighbour offsets applied around every sampled mask centre.
    mask_list = MASK_LIST['6']

    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    # Sample mask centres with probability config['mlm_probability']; special
    # tokens and padding get probability 0 so they are never chosen as centres.
    probability_matrix = torch.full(labels.shape, config['mlm_probability'])
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer.pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)

    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Iterate over a snapshot of the sampled centres so that widening the mask
    # below (which writes into masked_indices) does not affect the iteration.
    masks = deepcopy(masked_indices)
    for i, masked_index in enumerate(masks):
        # Positions with non-zero probability are the maskable (non-special,
        # non-padding) tokens of this row.
        non_zero_indices = torch.where(probability_matrix[i] != 0)[0]
        if non_zero_indices.numel() == 0:
            # Nothing maskable in this row; skipping also avoids the
            # IndexError that [-1] would raise on an empty list.
            continue

        # Last maskable position; the widened mask may not extend past it.
        end = non_zero_indices.tolist()[-1]
        mask_centers = set(torch.where(masked_index == 1)[0].tolist())
        new_centers = deepcopy(mask_centers)
        for center in mask_centers:
            for mask_number in mask_list:
                current_index = center + mask_number
                # Keep neighbours inside [1, end]; index 0 is excluded
                # (presumably the leading special token - TODO confirm).
                if current_index <= end and current_index >= 1:
                    new_centers.add(current_index)
        new_centers = list(new_centers)
        masked_indices[i][new_centers] = True

    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    # (half of the 20% that were not replaced by the mask token above)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels

import os
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler
from tqdm import tqdm, trange
from transformers import PreTrainedModel, PreTrainedTokenizer, AdamW, get_linear_schedule_with_warmup
from typing import List, Dict, Tuple
import wandb
import time

def train(config, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Train ``model`` on ``train_dataset`` and return ``(global_step, average_loss)``.

    Batches are padded to a common length, optionally masked for MLM
    (``config['mlm']``), and optimized with AdamW under a linear
    warmup/decay schedule. Metrics go to TensorBoard and wandb; checkpoints
    are written under ``/kaggle/working/output`` every
    ``config['save_steps']`` optimizer steps and rotated according to
    ``config['save_total_limit']``.

    Args:
        config: Mutable mapping of hyper-parameters. ``train_batch_size`` is
            written to it, and ``num_train_epochs`` is overwritten when
            ``max_steps`` is positive.
        train_dataset: Dataset yielding 1-D tensors of token ids.
        model: Model to train; must already be on ``config['device']``.
        tokenizer: Tokenizer used for padding, masking and checkpoint saving.

    Returns:
        The number of optimizer steps taken and the accumulated loss divided
        by that number (the accumulated loss itself if no step was taken).
    """
    if config['local_rank'] in [-1, 0]:
        tb_writer = SummaryWriter()

    config['train_batch_size'] = config['per_gpu_train_batch_size'] * max(1, config['n_gpu'])

    def collate(examples: List[torch.Tensor]):
        # Pad with the tokenizer's pad id when it has one, otherwise with 0.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if config['local_rank'] == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=config['train_batch_size'], collate_fn=collate
    )

    if config['max_steps'] > 0:
        t_total = config['max_steps']
        config['num_train_epochs'] = config['max_steps'] // (len(train_dataloader) // config['gradient_accumulation_steps']) + 1
    else:
        t_total = len(train_dataloader) // config['gradient_accumulation_steps'] * config['num_train_epochs']

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": config['weight_decay'],
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'], eps=config['adam_epsilon'], betas=(config['beta1'],config['beta2']))
    # NOTE(review): the warmup length is hard-coded; config['warmup_steps'] is
    # not consulted. Left unchanged so the schedule stays as it was - confirm
    # which value is intended.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=2000, num_training_steps=t_total
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", config['num_train_epochs'])
    logger.info("  Instantaneous batch size per GPU = %d", config['per_gpu_train_batch_size'])
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        config['train_batch_size']
        * config['gradient_accumulation_steps']
        * (torch.distributed.get_world_size() if config['local_rank'] != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", config['gradient_accumulation_steps'])
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    # These stay 0: optimizer/scheduler/step state is not restored here, so
    # training always starts from the first step.
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(config['num_train_epochs']), desc="Epoch", disable=config['local_rank'] not in [-1, 0]
    )
    set_seed(config)  # Added here for reproducibility

    for epoch in train_iterator:
        epoch_start_time = time.time()
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=config['local_rank'] not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # For MLM the batch is masked (in place); otherwise the inputs
            # double as labels.
            inputs, labels = mask_tokens(batch, tokenizer, config) if config['mlm'] else (batch, batch)

            inputs = inputs.to(config['device'])
            labels = labels.to(config['device'])
            model.train()
            # Both the MLM and the causal-LM path use the same call.
            outputs = model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if config['n_gpu'] > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if config['gradient_accumulation_steps'] > 1:
                loss = loss / config['gradient_accumulation_steps']

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % config['gradient_accumulation_steps'] == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics to wandb
                wandb.log({"learning_rate": scheduler.get_last_lr()[0], "loss": loss.item(), "global_step": global_step})

                if config['local_rank'] in [-1, 0] and config['logging_steps'] > 0 and global_step % config['logging_steps'] == 0:
                    # Log metrics
                    if (
                        config['local_rank'] == -1 and config['evaluate_during_training']
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(config, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                            wandb.log({f"eval_{key}": value, "global_step": global_step})
                    # Fix: use get_last_lr() here as well, matching the wandb
                    # logging above; get_lr() is not meant to be called
                    # outside of scheduler.step().
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / config['logging_steps'], global_step)
                    logging_loss = tr_loss

                if config['local_rank'] in [-1, 0] and config['save_steps'] > 0 and global_step % config['save_steps'] == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    st = r"/kaggle/working/output"
                    output_dir = os.path.join(st, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(config, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(config, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if config['max_steps'] > 0 and global_step > config['max_steps']:
                epoch_iterator.close()
                break
        if config['max_steps'] > 0 and global_step > config['max_steps']:
            train_iterator.close()
            break
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        # Log epoch time
        output_dir = r"/kaggle/working/output"
        # Fix: go through the module logger like every other message in this
        # file instead of the root logger.
        logger.info('Epoch %d: Time %.4fs', epoch + 1, epoch_time)
        log_dir = os.path.join(output_dir, 'training_logs')
        os.makedirs(log_dir, exist_ok=True)
        file = os.path.join(log_dir,'log.txt')
        with open(file, 'a') as f:
            f.write(f"Epoch {epoch + 1}/{config['num_train_epochs']}:\n")
            f.write(f"  Epoch Time: {epoch_time}\n")

        # Log epoch time to wandb
        wandb.log({"epoch_time": epoch_time, "epoch": epoch + 1})

    if config['local_rank'] in [-1, 0]:
        tb_writer.close()

    # Fix: avoid ZeroDivisionError when no optimizer step was taken (empty
    # dataset, or fewer batches than gradient_accumulation_steps).
    return global_step, tr_loss / max(global_step, 1)


import os
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Dict
import wandb

def evaluate(config, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Compute perplexity of ``model`` on the evaluation dataset.

    The dataset is loaded through ``load_and_cache_examples``; batches are
    padded, optionally masked for MLM (``config['mlm']``), and the mean batch
    loss is exponentiated. The result is logged to wandb and appended to
    ``<config['output_dir']>/<prefix>/eval_results.txt``.

    Args:
        config: Mutable mapping of settings; ``eval_batch_size`` is written
            to it.
        model: Model to evaluate; wrapped in ``DataParallel`` when
            ``config['n_gpu'] > 1``. Left in eval mode on return.
        tokenizer: Tokenizer used for padding and masking.
        prefix: Sub-directory (and log label) for this evaluation's results.

    Returns:
        ``{"perplexity": float}``; the value is NaN when the evaluation
        dataset yields no batches.
    """
    eval_output_dir = config['output_dir']

    eval_dataset = load_and_cache_examples(config, tokenizer, evaluate=True)

    if config['local_rank'] in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    config['eval_batch_size'] = config['per_gpu_eval_batch_size'] * max(1, config['n_gpu'])

    def collate(examples: List[torch.Tensor]):
        # Pad with the tokenizer's pad id when it has one, otherwise with 0.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Evaluation always iterates in order, also in distributed runs.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=config['eval_batch_size'], collate_fn=collate
    )

    # multi-gpu evaluate
    if config['n_gpu'] > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", config['eval_batch_size'])
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, config) if config['mlm'] else (batch, batch)
        inputs = inputs.to(config['device'])
        labels = labels.to(config['device'])

        with torch.no_grad():
            # Both the MLM and the causal-LM path use the same call.
            outputs = model(inputs, labels=labels)
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        # Fix: an empty evaluation set used to raise ZeroDivisionError, which
        # would abort a training run that evaluates periodically.
        logger.warning("Evaluation dataset produced no batches; perplexity is undefined.")
        eval_loss = float("nan")
    else:
        eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity.item()}

    # Log metrics to wandb (the whole result dict is logged under one key).
    wandb.log({"eval perplexity" : result})

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # Fix: with a non-empty prefix the file lives in a sub-directory that the
    # makedirs call above does not create.
    results_dir = os.path.dirname(output_eval_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(output_eval_file, "a") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


import argparse
import os
import logging
import torch
import os
import logging
import torch

def main(config):
    """Run training and/or evaluation as described by ``config``.

    Sets up the device, logging and seeds, builds the DNA tokenizer and a
    BERT masked-LM (from a checkpoint when ``config['model_name_or_path']``
    is set, otherwise from scratch), then trains (``do_train``), saves and
    reloads the final model, and evaluates (``do_eval``).

    Args:
        config: Mutable mapping of settings. ``n_gpu``, ``device`` and
            ``block_size`` are written to it, and ``model_name_or_path`` is
            set to the latest checkpoint when ``should_continue`` is truthy.

    Returns:
        Dict of evaluation results keyed ``"<metric>_<global_step>"``
        (empty when no evaluation ran).

    Raises:
        ValueError: If ``should_continue`` is set but no checkpoint exists,
            or if the output directory is non-empty and
            ``overwrite_output_dir`` is not set while training.
    """
    # Handle checkpoint continuation
    if config.get('should_continue', False):
        sorted_checkpoints = _sorted_checkpoints(config)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            config['model_name_or_path'] = sorted_checkpoints[-1]

    output_dir = config.get('output_dir', './output')
    if (
        os.path.exists(output_dir)
        and os.listdir(output_dir)
        and config.get('do_train', False)
        and not config.get('overwrite_output_dir', False)
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                output_dir
            )
        )

    # Setup CUDA, GPU & distributed training
    if config.get('local_rank', -1) == -1 or config.get('no_cuda', False):
        device = torch.device("cuda:0" if torch.cuda.is_available() and not config.get('no_cuda', False) else "cpu")
        config['n_gpu'] = torch.cuda.device_count()
    else:
        torch.cuda.set_device(config.get('local_rank', 0))
        device = torch.device("cuda", config.get('local_rank', 0))
        torch.distributed.init_process_group(backend="nccl")
        config['n_gpu'] = 1
    config['device'] = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if config.get('local_rank', -1) in [-1, 0] else logging.WARN,
        filename = 'app.log'
    )
    logger = logging.getLogger(__name__)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        config.get('local_rank', -1),
        device,
        config['n_gpu'],
        bool(config.get('local_rank', -1) != -1),
        config.get('fp16', False),
    )

    # Set seed
    set_seed(config)

    # Load pretrained model and tokenizer
    if config.get('local_rank', -1) not in [-1, 0]:
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    config_class, model_class, tokenizer_class = MODEL_CLASSES['dna']
    config_obj = config_class.from_pretrained('prajjwal1/bert-tiny', cache_dir=config.get('cache_dir', None))

    tokenizer = tokenizer_class.from_pretrained('zhihan1996/DNA_bert_6', cache_dir=config.get('cache_dir', None))

    # Clamp the block size to the 512-token limit used throughout this script.
    if config.get('block_size', 512) <= 0:
        config['block_size'] = 512
    else:
        config['block_size'] = min(config['block_size'], 512)

    if config.get('model_name_or_path'):
        # Fix: this load was commented out and replaced by `pass`, leaving
        # `model` undefined (NameError below) whenever a checkpoint path was
        # configured, e.g. through should_continue.
        model = model_class.from_pretrained(
            config['model_name_or_path'],
            from_tf=bool(".ckpt" in config['model_name_or_path']),
            config=config_obj,
            cache_dir=config.get('cache_dir', None),
        )
    else:
        logger.info("Training new model from scratch")
        model = model_class(config=config_obj)

    model.to(config['device'])

    if config.get('local_rank', -1) == 0:
        torch.distributed.barrier()

    logger.info("Training/evaluation parameters %s", config)

    # Training
    if config.get('do_train', False):
        if config.get('local_rank', -1) not in [-1, 0]:
            torch.distributed.barrier()

        train_dataset = load_and_cache_examples(config, tokenizer, evaluate=False)

        if config.get('local_rank', -1) == 0:
            torch.distributed.barrier()

        global_step, tr_loss = train(config, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Save the final model, then reload it so evaluation uses what is on disk
    if config.get('do_train', False) and (config.get('local_rank', -1) == -1 or torch.distributed.get_rank() == 0):
        if config.get('local_rank', -1) in [-1, 0]:
            os.makedirs(output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", output_dir)
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(config, os.path.join(output_dir, "training_args.bin"))

        model = model_class.from_pretrained(output_dir)
        tokenizer = tokenizer_class.from_pretrained(output_dir)
        model.to(config['device'])

    # Evaluation
    results = {}
    if config.get('do_eval', False) and config.get('local_rank', -1) in [-1, 0]:
        checkpoints = [output_dir]
        if config.get('eval_all_checkpoints', False):
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # The step suffix is only meaningful when several checkpoints are
            # evaluated; a single run is keyed with an empty suffix.
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(config['device'])
            result = evaluate(config, model, tokenizer, prefix=prefix)
            result = {k + "_{}".format(global_step): v for k, v in result.items()}
            results.update(result)

    return results

# Example configuration dictionary passed to main().
# Keys marked "not read" are not consulted by the code visible in this file.
config = {
    'line_by_line': True,  # one example per line (LineByLineTextDataset) instead of fixed-size blocks
    'should_continue': False,  # resume from the latest checkpoint; main() raises ValueError if none exists
    'mlm': True,  # apply mask_tokens() to every batch
    'mlm_probability': 0.15,  # per-token probability of being sampled as a mask centre
    'config_name': None,  # not read
    'tokenizer_name': None,  # not read
    'cache_dir': None,
    'block_size': 512,  # main() clamps this to at most 512
    'do_train': True,
    'do_eval': True,
    'evaluate_during_training': True,
    'per_gpu_train_batch_size': 175,
    'per_gpu_eval_batch_size': 25,
    'gradient_accumulation_steps': 1,
    'learning_rate': 4e-4,
    'weight_decay': 0.01,
    'adam_epsilon': 1e-6,
    'beta1': 0.9,
    'beta2': 0.98,
    'max_grad_norm': 1.0,
    'num_train_epochs': 2000,
    'max_steps': -1,  # values <= 0 mean "train for num_train_epochs"
    'warmup_steps': 100,  # not read: train() hard-codes num_warmup_steps=2000
    'logging_steps': 200,
    'save_steps': 1000,
    'save_total_limit': 10,  # maximum number of checkpoints kept by _rotate_checkpoints()
    'eval_all_checkpoints': False,
    'no_cuda': False,
    'overwrite_output_dir': True,
    'overwrite_cache': False,
    'seed': 42,
    'n_process': 1,  # worker processes used to encode lines in LineByLineTextDataset
    'fp16': False,  # only echoed in the startup log message
    'fp16_opt_level': 'O1',  # not read
    'local_rank': -1,  # -1 means non-distributed
    'server_ip': '',  # not read
    'server_port': '',  # not read
    'output_dir': './output',  # final model and eval results; step checkpoints go to /kaggle/working/output
    'device':'cuda'  # overwritten by main() with the detected torch.device
}

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main(config)
