PoTaTo721's picture
Upload Fish-Agent Demo
4f6613a
# -*- coding: utf-8 -*-
"""Basic methods for Chinese numerals.

- Create a Chinese number system (``create_system``).
- Chinese numeral string => Arabic digit string (``chn2num``).
- Arabic digit string => Chinese numeral string (``num2chn``).
"""
__author__ = "Zhiyang Zhou <[email protected]>"
# NOTE(review): this looks like it was meant to be ``__date__``; the name is
# kept unchanged in case anything reads ``__data__``.
__data__ = "2019-05-02"
from fish_speech.text.chn_text_norm.basic_class import *
from fish_speech.text.chn_text_norm.basic_constant import *
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """Build the Chinese number system for ``numbering_type`` (default 'mid').

    ``NUMBERING_TYPES = ['low', 'mid', 'high']`` selects how the large units
    scale:

    - low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
    - mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
    - high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.

    Returns:
        NumberSystem whose ``units`` are the small units ('十', '百', '千', '万')
        followed by the large units ('亿' and above), whose ``digits`` are 0-9
        (with alternate forms set for 0, 1 and 2), and whose ``math`` holds the
        positive, negative and decimal-point symbols.
    """
    # Units from '亿' upward: their power depends on the numbering type.
    large_unit_pairs = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
        LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
    )
    big_units = []
    for position, pair in enumerate(large_unit_pairs):
        big_units.append(CNU.create(position, pair, numbering_type, False))

    # Units '十, 百, 千, 万'.
    small_unit_pairs = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
        SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
    )
    little_units = []
    for position, pair in enumerate(small_unit_pairs):
        little_units.append(CNU.create(position, pair, small_unit=True))

    # Digits 0-9: (simplified, traditional, big simplified, big traditional).
    digit_forms = zip(
        CHINESE_DIGIS,
        CHINESE_DIGIS,
        BIG_CHINESE_DIGIS_SIMPLIFIED,
        BIG_CHINESE_DIGIS_TRADITIONAL,
    )
    digit_symbols = []
    for position, forms in enumerate(digit_forms):
        digit_symbols.append(CND.create(position, forms))

    # Alternate spellings for 0, 1 and 2 (simplified, traditional).
    alternate_forms = (
        (ZERO_ALT, ZERO_ALT),
        (ONE_ALT, ONE_ALT),
        (TWO_ALTS[0], TWO_ALTS[1]),
    )
    for digit, (alt_simplified, alt_traditional) in zip(digit_symbols, alternate_forms):
        digit.alt_s = alt_simplified
        digit.alt_t = alt_traditional

    # Sign and decimal-point symbols.
    math_symbols = MathSymbol(
        CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x),
        CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x),
        CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))),
    )

    number_system = NumberSystem()
    number_system.units = little_units + big_units
    number_system.digits = digit_symbols
    number_system.math = math_symbols
    return number_system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string into an Arabic numeral string.

    Example: '一万零四百零三点八零五' -> '10403.805'.

    Args:
        chinese_string: the Chinese numeral. It may start with one sign
            character (``POSITIVE`` / ``NEGATIVE``) and may contain at most
            one decimal-point character (``POINT``). In the integer part,
            characters that are not units, digits or math symbols are ignored.
        numbering_type: one of ``NUMBERING_TYPES``; selects how large units
            such as '兆' are valued (see ``create_system``).

    Returns:
        str: the integer value, followed by '.' and the decimal digits when a
        decimal point is present, prefixed with '-' when the input starts
        with a negative sign.

    Raises:
        ValueError: if the string contains more than one decimal point.
    """

    def get_symbol(char, system):
        """Return the unit, digit or math symbol matching ``char``, else None."""
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [
                d.traditional,
                d.simplified,
                d.big_s,
                d.big_t,
                d.alt_s,
                d.alt_t,
            ]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m
        return None

    def string2symbols(chinese_string, system):
        """Split at the decimal point and map every character to a symbol.

        Returns a pair (integer-part symbols, decimal-part symbols).
        """
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                parts = chinese_string.split(p)
                if len(parts) != 2:
                    raise ValueError(
                        "invalid input chinese string with more than one point: {}".format(
                            chinese_string
                        )
                    )
                int_string, dec_string = parts
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """Normalize abbreviated forms into an explicit digit/unit sequence.

        - A leading '十' gets an implicit '一' ('十五' -> '一十五').
        - A trailing digit right after a unit is read one unit lower
          ('一百八' -> '一百八十').
        - A unit directly following another unit multiplies every smaller unit
          emitted so far ('一亿一千三百万' -> 一亿 一千万 三百万).
        """
        # Work on a copy so the caller's list is never mutated.
        integer_symbols = list(integer_symbols)
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(
                integer_symbols[-2], CNU
            ):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None)
                )
        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1
                if unit_count == 1:
                    result.append(current_unit)
                else:
                    # Consecutive units (e.g. '千万'): fold the new unit into
                    # every smaller unit already emitted instead of appending.
                    for idx, prev in enumerate(result):
                        if isinstance(prev, CNU) and prev.power < current_unit.power:
                            result[idx] = CNU(
                                prev.power + current_unit.power,
                                None,
                                None,
                                None,
                                None,
                            )
        return result

    def compute_value(integer_symbols):
        """Compute the integer value of a normalized symbol sequence.

        When the current unit is larger than the previous unit, all previously
        accumulated terms are multiplied by it as well, e.g.
        '两千万' = 2000 * 10000, not 2000 + 10000.
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    scale = pow(10, s.power)
                    value[:-1] = [v * scale for v in value[:-1]]
                    last_power = s.power
                value.append(0)
        return sum(value)

    # A leading sign character is handled here; left in place it would be
    # ignored by compute_value and would also hide a leading '十'.
    sign = ""
    if chinese_string[:1] in (NEGATIVE[0], NEGATIVE[1]):
        sign, chinese_string = "-", chinese_string[1:]
    elif chinese_string[:1] in (POSITIVE[0], POSITIVE[1]):
        chinese_string = chinese_string[1:]

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    if dec_part:
        dec_str = "".join(str(d.value) for d in dec_part)
        return "{0}{1}.{2}".format(sign, int_str, dec_str)
    return sign + int_str
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert an Arabic numeral string into a Chinese numeral string.

    Examples: '10260' -> '一万零二百六十'; '10' -> '十'; '0.5' -> '零点五';
    '000' -> '零'.

    Args:
        number_string: decimal digits with at most one '.'.
        numbering_type: one of ``NUMBERING_TYPES`` (see ``create_system``).
        big: use the "big" digit/unit forms. In that case '两' is not used
            and ``alt_two`` has no effect on the output.
        traditional: use traditional instead of simplified forms.
        alt_zero: replace the zero character with the alternate zero form.
        alt_one: replace the one character with the alternate one form.
        alt_two: use the alternate form of 2 ('两') when the 2 directly
            precedes a unit other than '十', and is either first or directly
            follows a unit other than '十'.
        use_zeros: accepted for backward compatibility. It currently has no
            effect; interior zeros are always emitted.
        use_units: if False (or the integer part is a single character), read
            the integer part digit by digit without units.

    Returns:
        str: the Chinese numeral.

    Raises:
        ValueError: if ``number_string`` contains more than one '.', or a
            non-digit character.
    """

    def get_value(value_string, use_zeros=True):
        """Recursively turn a digit string into digit/unit symbols.

        Leading zeros are stripped, and an all-zero string yields []. A single
        significant digit preceded by zeros is emitted as '零' + digit (when
        ``use_zeros``).
        """
        striped_string = value_string.lstrip("0")
        # Nothing to record if all zeros.
        if not striped_string:
            return []
        # A single significant digit.
        if len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            return [system.digits[int(striped_string)]]
        # Several digits: pick the last unit in system.units whose power is
        # below the number of significant digits, then split around it.
        result_unit = next(
            u for u in reversed(system.units) if u.power < len(striped_string)
        )
        result_string = value_string[: -result_unit.power]
        return (
            get_value(result_string)
            + [result_unit]
            + get_value(striped_string[-result_unit.power :])
        )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string, dec_string = int_dec[0], ""
    elif len(int_dec) == 2:
        int_string, dec_string = int_dec[0], int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        # get_value returns [] for an all-zero integer part (e.g. '00'); read
        # it as a single zero instead of silently dropping the number.
        result_symbols = get_value(int_string) or [system.digits[0]]
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = (
                    result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                )
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(
                    previous_symbol, (CNU, type(None))
                ):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # If big is True, '两' will not be used and `alt_two` has no impact on output.
    if big:
        attr_name = "big_t" if traditional else "big_s"
    else:
        attr_name = "traditional" if traditional else "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])

    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s
        )
    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s
        )

    # A bare decimal such as '.5' gets a leading zero.
    if result.startswith(tuple(POINT)):
        return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19: drop the leading '一' before '十'.
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [
            CHINESE_DIGIS[1],
            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
            BIG_CHINESE_DIGIS_TRADITIONAL[1],
        ]
    ):
        result = result[1:]
    return result
if __name__ == "__main__":
    # Demo / smoke test.
    all_chinese_number_string = (
        CHINESE_DIGIS
        + BIG_CHINESE_DIGIS_SIMPLIFIED
        + BIG_CHINESE_DIGIS_TRADITIONAL
        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + ZERO_ALT
        + ONE_ALT
        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
    )

    print("num:", chn2num("一万零四百零三点八零五"))
    print("num:", chn2num("一亿六点三"))
    print("num:", chn2num("一亿零六点三"))
    print("num:", chn2num("两千零一亿六点三"))
    print("txt:", num2chn("10260.03", alt_zero=True))
    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
    # num2chn has no `use_lzeros` / `use_rzeros` keywords (they raised
    # TypeError); `use_zeros` is the actual parameter.
    print(
        "txt:",
        num2chn(
            "059523810880",
            alt_one=True,
            alt_two=False,
            use_zeros=True,
            use_units=False,
        ),
    )
    print(all_chinese_number_string)