from transformers import BartTokenizer, BartForConditionalGeneration


def main():
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    num_added_tokens = tokenizer.add_special_tokens({
        "additional_special_tokens": ["<idiom>", "</idiom>"],  # beginning and end of an idiom
    })
    print(num_added_tokens)
    print(tokenizer.additional_special_tokens)  # the newly added special tokens show up here
    # then resize the model's embedding table to match the enlarged vocabulary
    print(bart.model.shared.weight.shape)  # before
    bart.resize_token_embeddings(len(tokenizer))
    print(bart.model.shared.weight.shape)  # after


if __name__ == '__main__':
    main()

"""
2
['<idiom>', '</idiom>']
torch.Size([50265, 768])
torch.Size([50267, 768])  # two more embedding vectors have been added
Later, you may want to save the tokenizer after adding the idiom special tokens.
"""