File size: 3,449 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from tencentpretrain.utils.dataset import *
from tencentpretrain.utils.dataloader import *
from tencentpretrain.utils.act_fun import *
from tencentpretrain.utils.optimizers import *
from tencentpretrain.utils.adversarial import *


# Lookup table from a short string key to the tokenizer class it names.
str2tokenizer = {
    "char": CharTokenizer,
    "space": SpaceTokenizer,
    "bert": BertTokenizer,
    "bpe": BPETokenizer,
    "xlmroberta": XLMRobertaTokenizer,
    "image": ImageTokenizer,
    "text_image": TextImageTokenizer,
    "virtual": VirtualTokenizer,
}
# Lookup table from a short string key to the dataset class it names.
# The keys are the same as those of ``str2dataloader``.
str2dataset = {
    "bert": BertDataset,
    "lm": LmDataset,
    "mlm": MlmDataset,
    "bilm": BilmDataset,
    "albert": AlbertDataset,
    "mt": MtDataset,
    "t5": T5Dataset,
    "gsg": GsgDataset,
    "bart": BartDataset,
    "cls": ClsDataset,
    "prefixlm": PrefixlmDataset,
    "cls_mlm": ClsMlmDataset,
    "vit": VitDataset,
    "vilt": ViltDataset,
    "clip": ClipDataset,
    "s2t": S2tDataset,
    "beit": BeitDataset,
    "dalle": DalleDataset,
}
# Lookup table from a short string key to the dataloader class it names.
# The keys are the same as those of ``str2dataset``.
str2dataloader = {
    "bert": BertDataloader,
    "lm": LmDataloader,
    "mlm": MlmDataloader,
    "bilm": BilmDataloader,
    "albert": AlbertDataloader,
    "mt": MtDataloader,
    "t5": T5Dataloader,
    "gsg": GsgDataloader,
    "bart": BartDataloader,
    "cls": ClsDataloader,
    "prefixlm": PrefixlmDataloader,
    "cls_mlm": ClsMlmDataloader,
    "vit": VitDataloader,
    "vilt": ViltDataloader,
    "clip": ClipDataloader,
    "s2t": S2tDataloader,
    "beit": BeitDataloader,
    "dalle": DalleDataloader,
}

# Lookup table from an activation name to the function imported from
# ``tencentpretrain.utils.act_fun``.
str2act = {
    "gelu": gelu,
    "gelu_fast": gelu_fast,
    "relu": relu,
    "silu": silu,
    "linear": linear,
}

# Lookup table from an optimizer name to the optimizer class it names.
str2optimizer = {
    "adamw": AdamW,
    "adafactor": Adafactor,
}

# Lookup table from a scheduler name to the ``get_*`` schedule function
# registered for it.
str2scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_with_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    "constant": get_constant_schedule,
    "constant_with_warmup": get_constant_schedule_with_warmup,
    "inverse_sqrt": get_inverse_square_root_schedule_with_warmup,
    "tri_stage": get_tri_stage_schedule,
}

# Lookup table from a name to the class imported from
# ``tencentpretrain.utils.adversarial``.
str2adv = {
    "fgm": FGM,
    "pgd": PGD,
}

# Names re-exported by ``from tencentpretrain.utils import *``.
# Every class or function referenced by one of this module's ``str2*``
# registries is listed next to the registry itself. Star-importers can then
# reach both the lookup table and every object it maps to.
__all__ = [
    # Tokenizers.
    "CharTokenizer", "SpaceTokenizer", "BertTokenizer", "BPETokenizer", "XLMRobertaTokenizer",
    "ImageTokenizer", "TextImageTokenizer", "VirtualTokenizer", "str2tokenizer",
    # Datasets.
    "BertDataset", "LmDataset", "MlmDataset", "BilmDataset",
    "AlbertDataset", "MtDataset", "T5Dataset", "GsgDataset",
    "BartDataset", "ClsDataset", "PrefixlmDataset", "ClsMlmDataset",
    "VitDataset", "ViltDataset", "ClipDataset", "S2tDataset", "BeitDataset",
    "DalleDataset", "str2dataset",
    # Dataloaders.
    "BertDataloader", "LmDataloader", "MlmDataloader", "BilmDataloader",
    "AlbertDataloader", "MtDataloader", "T5Dataloader", "GsgDataloader",
    "BartDataloader", "ClsDataloader", "PrefixlmDataloader", "ClsMlmDataloader",
    "VitDataloader", "ViltDataloader", "ClipDataloader", "S2tDataloader", "BeitDataloader",
    "DalleDataloader", "str2dataloader",
    # Activation functions.
    "gelu", "gelu_fast", "relu", "silu", "linear", "str2act",
    # Optimizers.
    "AdamW", "Adafactor", "str2optimizer",
    # Schedulers.
    "get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup",
    "get_cosine_with_hard_restarts_schedule_with_warmup",
    "get_polynomial_decay_schedule_with_warmup",
    "get_constant_schedule", "get_constant_schedule_with_warmup",
    "get_inverse_square_root_schedule_with_warmup", "get_tri_stage_schedule",
    "str2scheduler",
    # Adversarial.
    "FGM", "PGD", "str2adv",
]