|
""" |
|
Create vocabulary (train tokenizer) |
|
|
|
Authors: |
|
* Heng-Jui Chang 2022 |
|
""" |
|
|
|
import logging
import os
import tempfile
from collections import Counter
from typing import List, Union

logger = logging.getLogger(__name__)

__all__ = ["generate_basic_vocab", "generate_subword_vocab", "generate_vocab"]


def generate_basic_vocab(
    mode: str,
    text_list: List[str],
    vocab_size: int = -1,
    coverage: float = 1.0,
    sort_vocab: bool = True,
) -> List[str]:
"""Generates basic vocabularies, including character and word-based vocabularies. |
|
|
|
Args: |
|
mode (str): Vocabulary type (character or word). |
|
text_list (List[str]): List of text data. |
|
vocab_size (int, optional): |
|
Vocabulary size, if not specified, vocab_size would be `coverage * actual vocab size`. Defaults to -1. |
|
coverage (float, optional): Vocabulary coverage. Defaults to 1.0. |
|
sort_vocab (bool, optional): Sort vocabularies alphabetically. Defaults to True. |
|
|
|
Returns: |
|
List[str]: A list of vocabularies. |
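
    Example:
        A minimal sketch with made-up sentences::

            >>> generate_basic_vocab("word", ["hello world", "hello there"])
            ['hello', 'there', 'world']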
|
""" |
|
|
|
    assert mode in {"character", "word"}, mode
    assert vocab_size == -1 or vocab_size > 0, vocab_size
    assert 0.0 < coverage <= 1.0, coverage

    logger.info(
        f"Generating vocab (type = {mode}, coverage = {coverage}) from {len(text_list)} sentences."
    )

    counter = Counter()

    # Count token occurrences: characters for character vocabularies,
    # whitespace-separated words for word vocabularies.
    for text in text_list:
        if mode == "character":
            counter.update(text)
        elif mode == "word":
            counter.update(text.split())

    # Resolve the target size: derive it from coverage if unspecified,
    # otherwise cap it at the number of distinct tokens.
    if vocab_size < 0:
        vocab_size = int(len(counter) * coverage)
    else:
        vocab_size = min(vocab_size, len(counter))

    # Keep only the most frequent tokens if truncation is needed.
    if vocab_size < len(counter):
        vocab_list = sorted(counter.keys(), key=lambda k: counter[k], reverse=True)
        vocab_list = vocab_list[:vocab_size]
    else:
        vocab_list = list(counter.keys())

    if sort_vocab:
        vocab_list = sorted(vocab_list)

    logger.info(f"Generated {vocab_size} {mode} vocabularies.")

    return vocab_list


def generate_subword_vocab(
    text_list: List[str] = None,
    text_file: str = None,
    output_file: str = None,
    vocab_size: int = 1000,
    character_coverage: float = 1.0,
) -> str:
"""Generates subword vocabularies based on `sentencepiece`. |
|
|
|
Args: |
|
text_list (List[str], optional): List of text data. Defaults to None. |
|
text_file (str, optional): Path to text data. Defaults to None. |
|
output_file (str, optional): Path to save trained subword vocabularies. Defaults to "". |
|
vocab_size (int, optional): Vocabulary size. Defaults to 8000. |
|
character_coverage (float, optional): Coverage of characters in text data. Defaults to 1.0. |
|
|
|
Raises: |
|
ImportError: If `sentencepiece` is not installed. |
|
|
|
Returns: |
|
str: Path to `${output_file}.model`. |
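
    Example:
        A minimal sketch (the output path and sentences are made up; the
        corpus must be large enough to support the requested ``vocab_size``)::

            model_path = generate_subword_vocab(
                text_list=["hello world", "hello there"],
                output_file="/tmp/subword",
                vocab_size=16,
            )
            # model_path == "/tmp/subword.model"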
|
""" |
|
|
|
    try:
        import sentencepiece as splib
    except ImportError:
        raise ImportError(
            "`sentencepiece` cannot be imported, please run `pip install sentencepiece` first"
        )

    assert output_file is not None, "output_file must be specified"
    output_file = str(output_file)
    assert vocab_size > 0, vocab_size
    assert (
        text_list is not None or text_file is not None
    ), "Either text_list or text_file must be provided"

    # SentencePiece training options: a unigram model with fixed special token
    # ids (pad = 0, eos = 1, unk = 2) and no bos token.
    cmd = (
        "--input={} --model_prefix={} --model_type=unigram "
        "--vocab_size={} --character_coverage={} "
        "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 "
        "--eos_piece=<eos> --remove_extra_whitespaces=true "
    )

    if text_list is not None:
        assert isinstance(text_list, list)
        assert isinstance(text_list[0], str)

        logger.info(
            f"Generating vocab (type = subword, coverage = {character_coverage}) from {len(text_list)} sentences."
        )

        # Write the sentences to a temporary file so sentencepiece can read them.
        with tempfile.TemporaryDirectory() as directory:
            input_file = os.path.join(directory, "text.txt")
            with open(input_file, "w", encoding="UTF-8") as fp:
                for text in text_list:
                    fp.write(text + "\n")

            cmd = cmd.format(
                input_file,
                output_file,
                vocab_size,
                character_coverage,
            )
            splib.SentencePieceTrainer.Train(cmd)

    if text_file is not None:
        logger.info(
            f"Generating vocab (type = subword, coverage = {character_coverage}) from {text_file}"
        )

        cmd = cmd.format(
            text_file,
            output_file,
            vocab_size,
            character_coverage,
        )
        splib.SentencePieceTrainer.Train(cmd)

    return output_file + ".model"


def generate_vocab(
    mode: str,
    text_list: List[str] = None,
    text_file: str = None,
    read_lines: int = 10000000,
    **vocab_args,
) -> Union[List[str], str]:
"""Generates vocabularies given text data. |
|
|
|
Args: |
|
mode (str): Vocabulary type |
|
text_list (List[str], optional): List of text data. Defaults to None. |
|
text_file (str, optional): Path to text data. Defaults to None. |
|
read_lines (int, optional): Maximum lines to read from `text_file`. Defaults to 10000000. |
|
vocab_args: |
|
if :code:`mode != subword`, arguments for :obj:`generate_basic_vocab` |
|
if :code:`mode == subword`, arguments for :obj:`generate_subword_vocab` |
|
|
|
Returns: |
|
Union[List[str], str]: A list of vocabularies or a path to `.vocab` file. |
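
    Example:
        A minimal sketch (file and output paths are made up)::

            # From an in-memory list of sentences.
            char_vocab = generate_vocab("character", text_list=["hello world"])

            # From a text file with one sentence per line; extra arguments are
            # forwarded to generate_subword_vocab.
            model_path = generate_vocab(
                "subword", text_file="train.txt", output_file="subword", vocab_size=500
            )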
|
""" |
|
|
|
    # Basic vocabularies require the text to be loaded into memory; subword
    # training lets sentencepiece read text_file directly.
    if text_list is None and mode in {"character", "word", "phoneme"}:
        assert isinstance(text_file, str)
        with open(text_file, "r", encoding="UTF-8") as fp:
            text_list = [
                line.strip("\r\n ") for i, line in enumerate(fp) if i < read_lines
            ]

if mode == "character": |
|
return generate_basic_vocab("character", text_list, **vocab_args) |
|
if mode in {"word", "phoneme"}: |
|
return generate_basic_vocab("word", text_list, **vocab_args) |
|
if mode == "subword": |
|
return generate_subword_vocab( |
|
text_list=text_list, text_file=text_file, **vocab_args |
|
) |
|
else: |
|
raise ValueError(f"Unsupported mode (vocabulary type): {mode}") |
|
|