# -*- encoding: utf-8 -*-
'''
@File    :   text_tokenizer.py
@Time    :   2021/12/20 01:26:12
@Author  :   Ming Ding 
@Contact :   [email protected]
'''

from typing import List

import sentencepiece as spm
from . import sentencepiece_model_pb2 as model


class TextTokenizer:
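    """Wrapper around a SentencePiece model that keeps the serialized proto in
    memory, so pieces can be added or re-scored and the processor reloaded."""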
    def __init__(self, model_path):
        self.proto = model.ModelProto()
        with open(model_path, 'rb') as fin:
            proto_str = fin.read()
            self.proto.ParseFromString(proto_str)
        self.refresh()
        
    def refresh(self):
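        """Rebuild the SentencePiece processor from the (possibly modified) proto."""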
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_proto=self.proto.SerializeToString())
        self.num_tokens = self.sp.vocab_size()

    def add_special_tokens(self, tokens):
        """Append new pieces (e.g. special tokens) to the model proto and reload."""
        for token in tokens:
            # SentencePiece is the nested message type on ModelProto; construct it
            # directly rather than through a throwaway ModelProto instance.
            new_token = model.ModelProto.SentencePiece()
            new_token.piece = token
            new_token.score = 0.0
            self.proto.pieces.append(new_token)
        self.refresh()
    
    def discourage_tokens(self, tokens):
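        """Assign a very low score (-100) to the given pieces so the segmenter
        almost never selects them."""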
        if isinstance(tokens, str): # single token
            tokens = [tokens]
        for token in tokens:
            for piece in self.proto.pieces:    
                if piece.piece == token:
                    piece.score = -100
        self.refresh()
    
    def discourage_ids(self, ids):
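        """Same as discourage_tokens, but pieces are addressed by vocabulary id."""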
        if isinstance(ids, int):
            ids = [ids]
        for idx in ids:
            self.proto.pieces[idx].score = -100
        self.refresh()

    def encode(self, text):
        return self.sp.EncodeAsIds(text)

    def decode(self, ids: List[int]):
        return self.sp.DecodeIds(ids)

    def tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.sp.PieceToId(token) for token in tokens]

    def convert_token_to_id(self, token):
        return self.sp.PieceToId(token)

    def convert_id_to_token(self, idx):
        return self.sp.IdToPiece(idx)
    
    def __len__(self):
        return self.num_tokens
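

if __name__ == '__main__':
    # Minimal usage sketch. The model path below is a placeholder; point it at
    # an actual SentencePiece .model file for your checkpoint.
    tokenizer = TextTokenizer('/path/to/your.model')
    print(len(tokenizer))                        # vocabulary size
    ids = tokenizer.encode('Hello world')        # text -> token ids
    print(ids, tokenizer.decode(ids))            # token ids -> text
    print(tokenizer.tokenize('Hello world'))     # text -> piece strings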