"""
Create vocabulary (train tokenizer)
Authors:
* Heng-Jui Chang 2022
"""
import logging
import os
import tempfile
from collections import Counter
from typing import List, Optional, Union
logger = logging.getLogger(__name__)
__all__ = ["generate_basic_vocab", "generate_subword_vocab", "generate_vocab"]
def generate_basic_vocab(
mode: str,
text_list: List[str],
vocab_size: int = -1,
coverage: float = 1.0,
sort_vocab: bool = True,
) -> List[str]:
"""Generates basic vocabularies, including character and word-based vocabularies.
Args:
mode (str): Vocabulary type (character or word).
text_list (List[str]): List of text data.
vocab_size (int, optional):
Vocabulary size, if not specified, vocab_size would be `coverage * actual vocab size`. Defaults to -1.
coverage (float, optional): Vocabulary coverage. Defaults to 1.0.
sort_vocab (bool, optional): Sort vocabularies alphabetically. Defaults to True.
Returns:
List[str]: A list of vocabularies.
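    Example (illustrative; with the defaults, the full vocabulary is kept and
    returned in sorted order)::
        >>> generate_basic_vocab("word", ["hello world", "hello there"])
        ['hello', 'there', 'world']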
"""
assert mode in {"character", "word"}, mode
assert vocab_size == -1 or vocab_size > 0, vocab_size
assert coverage > 0.0 and coverage <= 1.0, coverage
logger.info(
f"Generating vocab (type = {mode}, coverage = {coverage}) from {len(text_list)} sentences."
)
counter = Counter()
for text in text_list:
if mode == "character":
counter.update(text)
if mode == "word":
counter.update(text.split())
if vocab_size < 0:
vocab_size = int(len(counter) * coverage)
else:
vocab_size = min(vocab_size, len(counter))
if vocab_size < len(counter):
vocab_list = sorted(counter.keys(), key=lambda k: counter[k], reverse=True)
vocab_list = vocab_list[:vocab_size]
else:
vocab_list = list(counter.keys())
if sort_vocab:
vocab_list = sorted(vocab_list)
logger.info(f"Generated {vocab_size} {mode} vocabularies.")
return vocab_list
def generate_subword_vocab(
    text_list: Optional[List[str]] = None,
    text_file: Optional[str] = None,
    output_file: Optional[str] = None,
vocab_size: int = 1000,
character_coverage: float = 1.0,
) -> str:
"""Generates subword vocabularies based on `sentencepiece`.
Args:
        text_list (List[str], optional): List of text data. Defaults to None.
        text_file (str, optional): Path to text data. Defaults to None.
        output_file (str, optional): Path prefix for the trained subword model (required). Defaults to None.
        vocab_size (int, optional): Vocabulary size. Defaults to 1000.
        character_coverage (float, optional): Coverage of characters in text data. Defaults to 1.0.
Raises:
ImportError: If `sentencepiece` is not installed.
Returns:
str: Path to `${output_file}.model`.
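    Example (a minimal sketch; assumes `sentencepiece` is installed and the
    text is diverse enough for the requested `vocab_size`)::
        >>> model_path = generate_subword_vocab(
        ...     text_list=["hello world how are you"] * 100,
        ...     output_file="/tmp/subword",
        ...     vocab_size=30,
        ... )
        >>> model_path
        '/tmp/subword.model'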
"""
try:
import sentencepiece as splib
except ImportError:
raise ImportError(
"`sentencepiece` cannot be imported, please run `pip install sentencepiece` first"
)
    assert output_file is not None, "output_file must be specified"
output_file = str(output_file)
assert vocab_size > 0, vocab_size
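    # SentencePiece training flags: unigram model with custom special-token ids
    # (pad=0, eos=1, unk=2, bos disabled) and extra-whitespace removal.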
cmd = (
"--input={} --model_prefix={} --model_type=unigram "
"--vocab_size={} --character_coverage={} "
"--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 "
"--eos_piece=<eos> --remove_extra_whitespaces=true "
)
    if text_list is not None:
        assert isinstance(text_list, list) and len(text_list) > 0
        assert isinstance(text_list[0], str)
logger.info(
f"Generating vocab (type = subword, coverage = {character_coverage}) from {len(text_list)} sentences."
)
with tempfile.TemporaryDirectory() as directory:
input_file = os.path.join(directory, "text.txt")
with open(input_file, "w") as fp:
for text in text_list:
fp.write(text + "\n")
cmd = cmd.format(
input_file,
output_file,
vocab_size,
character_coverage,
)
splib.SentencePieceTrainer.Train(cmd)
    elif text_file is not None:
logger.info(
f"Generating vocab (type = subword, coverage = {character_coverage}) from {text_file}"
)
cmd = cmd.format(
text_file,
output_file,
vocab_size,
character_coverage,
)
splib.SentencePieceTrainer.Train(cmd)
return output_file + ".model"
def generate_vocab(
mode: str,
    text_list: Optional[List[str]] = None,
    text_file: Optional[str] = None,
read_lines: int = 10000000,
**vocab_args,
) -> Union[List[str], str]:
"""Generates vocabularies given text data.
Args:
        mode (str): Vocabulary type (`character`, `word`, `phoneme`, or `subword`).
text_list (List[str], optional): List of text data. Defaults to None.
text_file (str, optional): Path to text data. Defaults to None.
read_lines (int, optional): Maximum lines to read from `text_file`. Defaults to 10000000.
vocab_args:
if :code:`mode != subword`, arguments for :obj:`generate_basic_vocab`
if :code:`mode == subword`, arguments for :obj:`generate_subword_vocab`
Returns:
        Union[List[str], str]: A list of vocabulary tokens, or the path to the trained subword `.model` file.
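    Example (illustrative; dispatches to :obj:`generate_basic_vocab`)::
        >>> generate_vocab("word", text_list=["hello world", "hello there"])
        ['hello', 'there', 'world']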
"""
if text_list is None and mode in {"character", "word", "phoneme"}:
assert isinstance(text_file, str)
with open(text_file, "r", encoding="UTF-8") as fp:
text_list = [
line.strip("\r\n ") for i, line in enumerate(fp) if i < read_lines
]
if mode == "character":
return generate_basic_vocab("character", text_list, **vocab_args)
if mode in {"word", "phoneme"}:
return generate_basic_vocab("word", text_list, **vocab_args)
if mode == "subword":
return generate_subword_vocab(
text_list=text_list, text_file=text_file, **vocab_args
)
else:
raise ValueError(f"Unsupported mode (vocabulary type): {mode}")