File size: 3,076 Bytes
1bf3d62
 
 
 
 
 
 
 
25e310b
1bf3d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539e83f
1bf3d62
 
539e83f
1bf3d62
539e83f
1bf3d62
 
 
25e310b
1bf3d62
539e83f
1bf3d62
 
 
 
 
 
ec156ad
1bf3d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import torch.cuda
import wandb
import argparse
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from termcolor import colored
from transformers import BertForMaskedLM, BertTokenizer
from idiomify.datamodules import IdiomifyDataModule
from idiomify.fetchers import fetch_config, fetch_idioms
from idiomify.models import Alpha, Gamma
from idiomify.paths import ROOT_DIR
from idiomify import tensors as T


def main():
    """Train an idiomify model and, when requested, upload its checkpoint to wandb.

    Command-line flags (parsed with argparse, then merged into the fetched config):
        --model: key selecting the model section of the fetched config (default "alpha").
        --ver: key selecting the version inside that section (default "eng2eng").
        --num_workers: stored in the config passed to IdiomifyDataModule
            (default ``os.cpu_count()``).
        --log_every_n_steps: forwarded to ``pl.Trainer`` (default 1).
        --fast_dev_run: forwarded to ``pl.Trainer``; also suppresses the upload.
        --upload: upload the trained checkpoint as a wandb artifact.

    Raises:
        ValueError: if ``config['model']`` matches neither ``Alpha.name()``
            nor ``Gamma.name()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="alpha")
    parser.add_argument("--ver", type=str, default="eng2eng")
    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--upload", dest='upload', action='store_true', default=False)
    args = parser.parse_args()
    config = fetch_config()[args.model][args.ver]
    # command-line values override / extend whatever the config file provides
    config.update(vars(args))
    if not config['upload']:
        print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))

    # prepare arguments
    mlm = BertForMaskedLM.from_pretrained(config['bert'])
    tokenizer = BertTokenizer.from_pretrained(config['bert'])
    idioms = fetch_idioms(config['idioms_ver'])
    idiom2subwords = T.idiom2subwords(idioms, tokenizer, config['k'])
    # choose the model to train
    if config['model'] == Alpha.name():
        rd = Alpha(mlm, idiom2subwords, config['k'], config['lr'])
    elif config['model'] == Gamma.name():
        rd = Gamma(mlm, idiom2subwords, config['k'], config['lr'])
    else:
        raise ValueError(
            f"Unknown model: {config['model']!r}; "
            f"expected {Alpha.name()!r} or {Gamma.name()!r}"
        )
    # prepare datamodule
    datamodule = IdiomifyDataModule(config, tokenizer, idioms)

    with wandb.init(entity="eubinecto", project="idiomify-demo", config=config) as run:
        # log_model=False: the logger itself never uploads checkpoints
        logger = WandbLogger(log_model=False)
        trainer = pl.Trainer(max_epochs=config['max_epochs'],
                             fast_dev_run=config['fast_dev_run'],
                             log_every_n_steps=config['log_every_n_steps'],
                             gpus=torch.cuda.device_count(),
                             default_root_dir=str(ROOT_DIR),
                             enable_checkpointing=False,
                             logger=logger)
        # start training
        trainer.fit(model=rd, datamodule=datamodule)
        # NOTE(review): `>=` rather than `==` because, depending on the installed
        # Lightning release, current_epoch after a full run appears to be either
        # max_epochs - 1 or max_epochs -- confirm against the pinned version.
        training_completed = trainer.current_epoch >= config['max_epochs'] - 1
        # upload the model to wandb only if the user asked for it (--upload),
        # this is not a dev run, and the training is properly done
        if config['upload'] and not config['fast_dev_run'] and training_completed:
            ckpt_path = ROOT_DIR / "rd.ckpt"
            try:
                trainer.save_checkpoint(str(ckpt_path))
                artifact = wandb.Artifact(name=config['model'], type="model", metadata=config)
                artifact.add_file(str(ckpt_path))
                run.log_artifact(artifact, aliases=["latest", config['ver']])
            finally:
                # always remove the temporary checkpoint, even if saving or uploading failed
                if os.path.exists(str(ckpt_path)):
                    os.remove(str(ckpt_path))


# Script entry point: run the training only when executed directly, not on import.
if __name__ == "__main__":
    main()