#!/usr/bin/env python3
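"""Tokenize text and optionally build a vocabulary.

Reads text from --input (optionally selecting 1-based columns via --field),
applies an optional text cleaner, tokenizes each line with the chosen token
type (char/bpe/word/phn), and either writes the tokenized lines to --output
or, with --write_vocabulary, writes a frequency-sorted token list with extra
symbols such as '<blank>' inserted via --add_symbol.
"""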
import argparse
import logging
import sys
from collections import Counter
from pathlib import Path
from typing import List, Optional

from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
from funasr_detach.tokenizer.cleaner import TextCleaner
from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
from funasr_detach.utils.cli_utils import get_commandline_args
from funasr_detach.utils.types import str2bool, str_or_none


def field2slice(field: Optional[str]) -> slice:
"""Convert field string to slice | |
Note that field string accepts 1-based integer. | |
Examples: | |
>>> field2slice("1-") | |
slice(0, None, None) | |
>>> field2slice("1-3") | |
slice(0, 3, None) | |
>>> field2slice("-3") | |
slice(None, 3, None) | |
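        >>> field2slice("2")
        slice(1, 2, None)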
""" | |
field = field.strip() | |
try: | |
if "-" in field: | |
# e.g. "2-" or "2-5" or "-7" | |
s1, s2 = field.split("-", maxsplit=1) | |
if s1.strip() == "": | |
s1 = None | |
else: | |
s1 = int(s1) | |
if s1 == 0: | |
raise ValueError("1-based string") | |
if s2.strip() == "": | |
s2 = None | |
else: | |
s2 = int(s2) | |
else: | |
# e.g. "2" | |
s1 = int(field) | |
s2 = s1 + 1 | |
if s1 == 0: | |
raise ValueError("must be 1 or more value") | |
except ValueError: | |
raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}") | |
if s1 is None: | |
slic = slice(None, s2) | |
else: | |
# -1 because of 1-based integer following "cut" command | |
# e.g "1-3" -> slice(0, 3) | |
slic = slice(s1 - 1, s2) | |
return slic | |
def tokenize( | |
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
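    """Tokenize text from 'input' and write the result to 'output'.

    If write_vocabulary is False, each line is tokenized and written as
    space-joined tokens. If it is True, token counts are accumulated
    instead, and the output is a vocabulary: one token per line, with the
    add_symbol entries inserted at their requested indices and the
    remaining tokens sorted by descending frequency (subject to cutoff
    and vocabulary_size).
    """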
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)
        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # ======= write_vocabulary mode from here =======
    # Drop tokens given via --add_symbol from the counted vocabulary,
    # since they are inserted again later at fixed indices.
    for symbol_and_id in add_symbol:
        # e.g. symbol_and_id="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # Sort by the number of occurrences in descending order
    # and filter out words whose frequency does not exceed the cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )

    # Restrict the vocabulary size, reserving room for the added symbols
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g. symbol_and_id="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> insert as the first symbol
        # e.g. idx=-1 -> insert as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integers. e.g. 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-linguistic symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_classes,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )
    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()
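# Illustrative invocations (the script and file names below are placeholders,
# not part of this module):
#
#   Tokenize columns 2 onwards of a Kaldi-style "uttid text" file to characters:
#     python tokenize_text.py -i data/text -o data/tokens.txt -t char -f 2-
#
#   Build a 5000-entry vocabulary with special symbols at fixed positions:
#     python tokenize_text.py -i data/text -o data/vocab.txt -t char -f 2- \
#       --write_vocabulary true --vocabulary_size 5000 \
#       --add_symbol '<blank>:0' --add_symbol '<unk>:1' --add_symbol '<sos/eos>:-1'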