# -*- encoding: utf-8 -*-
'''
@File : text_tokenizer.py
@Time : 2021/12/20 01:26:12
@Author : Ming Ding
@Contact : [email protected]
'''
# here put the import lib
import os
import sys
import math
import random
from copy import copy
from typing import List
import sentencepiece as spm
from . import sentencepiece_model_pb2 as model


class TextTokenizer:
    def __init__(self, model_path):
        self.proto = model.ModelProto()
        with open(model_path, 'rb') as fin:
            proto_str = fin.read()
            self.proto.ParseFromString(proto_str)
        self.refresh()

    def refresh(self):
        # Rebuild the SentencePiece processor from the (possibly modified) proto.
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_proto=self.proto.SerializeToString())
        self.num_tokens = self.sp.vocab_size()

    def add_special_tokens(self, tokens):
        # Append new pieces to the model proto, then reload the processor.
        for token in tokens:
            new_token = model.ModelProto().SentencePiece()
            new_token.piece = token
            new_token.score = 0
            self.proto.pieces.append(new_token)
        self.refresh()

    def discourage_tokens(self, tokens):
        # Give the named pieces a very low score so they are rarely chosen
        # during segmentation.
        if isinstance(tokens, str):  # single token
            tokens = [tokens]
        for token in tokens:
            for piece in self.proto.pieces:
                if piece.piece == token:
                    piece.score = -100
        self.refresh()

    def discourage_ids(self, ids):
        # Same as discourage_tokens, but addressed by piece id.
        if isinstance(ids, int):  # single id
            ids = [ids]
        for idx in ids:
            self.proto.pieces[idx].score = -100
        self.refresh()

    def encode(self, text):
        return self.sp.EncodeAsIds(text)

    def decode(self, ids: List[int]):
        return self.sp.DecodeIds(ids)

    def tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.sp.PieceToId(token) for token in tokens]

    def convert_token_to_id(self, token):
        return self.sp.PieceToId(token)

    def convert_id_to_token(self, idx):
        return self.sp.IdToPiece(idx)

    def __len__(self):
        return self.num_tokens
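

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how TextTokenizer might be exercised end to end. The
# path 'tokenizer.model' is a hypothetical placeholder for a trained
# SentencePiece model file, and '<my_special>' is an example token. Because the
# module uses a relative import, run it as a package module, e.g.
# `python -m <package>.text_tokenizer`, rather than as a standalone script.
if __name__ == '__main__':
    tokenizer = TextTokenizer('tokenizer.model')   # hypothetical model path
    ids = tokenizer.encode('Hello world')          # list of token ids
    print(tokenizer.tokenize('Hello world'))       # list of subword pieces
    print(tokenizer.decode(ids))                   # round-trips back to text

    # Register an extra piece; it receives the next free id after the current vocab.
    tokenizer.add_special_tokens(['<my_special>'])
    print(len(tokenizer), tokenizer.convert_token_to_id('<my_special>'))

    # Make an existing piece very unlikely to be chosen (example id).
    tokenizer.discourage_ids([10])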