# Source: proteinglm-1b-mlm / tokenization_proteinglm.py
# (Hugging Face Hub upload by Bo1015: "Upload 9 files", commit 37b8123, verified)
"""Tokenization classes for ProteinGLM."""
import os
from typing import List, Optional, Union, Dict, Any
from torch import TensorType
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
# Maps the tokenizer's ``vocab_file`` constructor argument to the on-disk
# filename it is loaded from / saved to (presumably consumed by the
# transformers base class via ``vocab_files_names``).
VOCAB_FILES_NAMES = dict(vocab_file="tokenizer.model")
def load_vocab_file(vocab_file: str) -> List[str]:
    """Read a newline-delimited vocabulary file.

    Each line holds one token, with surrounding whitespace stripped. Blank
    lines are kept as empty strings rather than dropped, so that every token
    keeps its position (the tokenizer uses the position as the token id).

    Args:
        vocab_file: Path to the vocabulary file.

    Returns:
        The tokens in file order.
    """
    # Explicit UTF-8 so loading does not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(vocab_file, "r", encoding="utf-8") as f:
        return [line.strip() for line in f.read().splitlines()]
class ProteinGLMTokenizer(PreTrainedTokenizer):
    """
    Constructs a ProteinGLM tokenizer.

    The vocabulary is a plain-text file with one token per line; a token's id
    is its 0-based line index. Every vocabulary entry is registered with the
    base class's no-split trie, so (presumably) input text is first segmented
    at vocabulary tokens, and any remaining text is split on whitespace by
    ``_tokenize``.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        eos_token: str = "<eos>",
        model_max_length: int = 2048,
        additional_special_tokens: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Args:
            vocab_file: Path to the newline-delimited vocabulary file.
            unk_token: Token substituted for out-of-vocabulary tokens/ids.
            pad_token: Padding token.
            mask_token: Mask token.
            eos_token: End-of-sequence token, appended by
                ``build_inputs_with_special_tokens``.
            model_max_length: Maximum sequence length, forwarded to the base class.
            additional_special_tokens: Extra special tokens. Defaults to the
                ProteinGLM set (``<pad>``, ``<mask>``, ``<gmask>``, ``<smask>``,
                ``<eod>``, ``<sop>``, ``<eop>``, ``<eos>``, ``<unk>``).
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The vocabulary is loaded BEFORE super().__init__(), presumably because
        # the base constructor may call get_vocab()/vocab_size while registering
        # special tokens. Keep this order.
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
        if additional_special_tokens is None:
            additional_special_tokens = ['<pad>', '<mask>', '<gmask>', '<smask>', '<eod>', '<sop>', '<eop>', '<eos>', '<unk>']
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        # Register every vocabulary entry as a no-split token so amino-acid
        # letters and special tokens are matched directly inside raw text.
        self.unique_no_split_tokens = self.all_tokens
        self._update_trie(self.unique_no_split_tokens)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token; unknown ids map to ``self.unk_token``."""
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id; unknown tokens map to the id of ``unk_token``
        (``None`` if ``unk_token`` itself is not in the vocabulary)."""
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split text that remains after no-split-token matching on whitespace."""
        return text.split()

    def get_vocab(self) -> dict:
        """Return base vocabulary plus tokens added after construction."""
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    def token_to_id(self, token: str) -> int:
        """Public alias of ``_convert_token_to_id``."""
        return self._convert_token_to_id(token)

    def id_to_token(self, index: int) -> str:
        """Public alias of ``_convert_id_to_token``."""
        return self._convert_id_to_token(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append EOS to each sequence: ``A <eos>`` or ``A <eos> B <eos>``.

        If no EOS token is configured, a single sequence is returned unchanged
        and a pair raises ``ValueError``.
        """
        eos_id = self.eos_token_id
        if eos_id is None:
            if token_ids_1 is None:
                return token_ids_0
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        if token_ids_1 is None:
            return token_ids_0 + [eos_id]
        # Multiple inputs always have an EOS token after each segment.
        return token_ids_0 + [eos_id] + token_ids_1 + [eos_id]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the base vocabulary one token per line.

        Returns:
            A 1-tuple with the path of the written file
            (``[<prefix>-]tokenizer.model`` inside ``save_directory``).
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "tokenizer.model")
        # Explicit UTF-8 so the written file does not depend on the platform's
        # locale encoding.
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary file (excludes added tokens)."""
        return len(self.all_tokens)

    def apply_chat_template(
        self,
        query,
        add_generation_prompt: bool = True,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Prefix sequence(s) with the generation prompt and optionally encode them.

        Unlike the base-class chat template, ``query`` is raw sequence text,
        not a list of chat messages.

        Args:
            query: A single string or a list of strings.
            add_generation_prompt: If True, prepend ``"<gmask><sop><eos>"`` to
                each element (every element must then be a ``str``).
            tokenize: If False, return the (prompted) strings without encoding.
            padding, truncation, max_length, return_tensors: Forwarded to
                ``batch_encode_plus``.
            return_dict: If True (and ``tokenize``), return the whole
                ``BatchEncoding``; otherwise only its ``"input_ids"``.
            tokenizer_kwargs: Extra keyword arguments forwarded to
                ``batch_encode_plus``. Keys that duplicate the explicit
                arguments above raise ``TypeError``.
            add_special_tokens: Accepted for API compatibility but ignored;
                encoding always uses ``add_special_tokens=False``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            A list of strings when ``tokenize`` is False (a single ``str``
            query is wrapped in a list), otherwise the encoded ``input_ids``
            or the full ``BatchEncoding`` when ``return_dict`` is True.

        Raises:
            TypeError: If ``add_generation_prompt`` is True and an element of
                ``query`` is not a ``str``.
        """
        generation_prompt = "<gmask><sop><eos>"
        if isinstance(query, str):
            query = [query]
        if add_generation_prompt:
            for each in query:
                # Explicit check instead of `assert`, which is stripped under -O.
                if not isinstance(each, str):
                    raise TypeError(
                        f"Each query must be a str when add_generation_prompt=True, "
                        f"got {type(each).__name__}"
                    )
            prompt_query = [generation_prompt + each for each in query]
        else:
            prompt_query = query
        if not tokenize:
            return prompt_query
        output = self.batch_encode_plus(
            prompt_query,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            # Kept from the original implementation.
            # NOTE(review): confirm why inputs are flagged as pre-split words.
            is_split_into_words=True,
            add_special_tokens=False,
            **(tokenizer_kwargs or {}),
        )
        if return_dict:
            return output
        return output["input_ids"]