# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#           2023 Binbin Zhang ([email protected])
#           2023 Kaixun Huang
#           2023 Chengdong Liang ([email protected])
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from wenet.text.tokenize_utils import tokenize_by_bpe_model
from typing import Dict, List, Tuple
from collections import deque

def tokenize(context_list_path, symbol_table, bpe_model=None):
    """Read the biasing list from the given path, tokenize each word/phrase
    and convert it into a list of token ids.
    """
    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    with open(context_list_path, "r") as fin:
        context_txts = fin.readlines()

    context_list = []
    for context_txt in context_txts:
        context_txt = context_txt.strip()

        labels = []
        tokens = []
        if bpe_model is not None:
            tokens = tokenize_by_bpe_model(sp, context_txt)
        else:
            for ch in context_txt:
                if ch == ' ':
                    ch = "▁"
                tokens.append(ch)
        for ch in tokens:
            if ch in symbol_table:
                labels.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                labels.append(symbol_table['<unk>'])
        context_list.append(labels)
    return context_list
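
# An illustrative sketch, not part of the original file: assuming a toy
# char-level symbol table such as {'<unk>': 0, '你': 10, '好': 11} and a
# biasing-list file containing the single line "你好", tokenize() would
# return [[10, 11]]; any character missing from the symbol table falls back
# to the '<unk>' id when '<unk>' is present.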

class ContextState:
    """The state in ContextGraph"""

    def __init__(
        self,
        id: int,
        token: int,
        token_score: float,
        node_score: float,
        output_score: float,
        is_end: bool,
    ):
        """Create a ContextState.

        Args:
          id:
            The node id, only used for visualization now. A node id is in
            [0, graph.num_nodes); the id of the root node is always 0.
          token:
            The token id.
          token_score:
            The bonus for each token during decoding, which will hopefully
            boost the token enough to survive beam search.
          node_score:
            The accumulated bonus from the root of the graph to the current
            node; it is used to calculate the score of the fail arc.
          output_score:
            The total score of matched phrases, i.e. the sum of the node_score
            of all output nodes reachable from the current node.
          is_end:
            True if the current token is the end of a context phrase.
        """
        self.id = id
        self.token = token
        self.token_score = token_score
        self.node_score = node_score
        self.output_score = output_score
        self.is_end = is_end
        self.next = {}
        self.fail = None
        self.output = None

class ContextGraph:
    """The ContextGraph is modified from Aho-Corasick: it is essentially a
    Trie with a fail arc for each node.
    See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more
    details of the Aho-Corasick algorithm.

    A ContextGraph contains the words/phrases whose scores we expect to boost
    during decoding. If a substring of a decoded sequence matches a
    word/phrase in the ContextGraph, the decoded sequence is given a bonus to
    help it survive beam search.
    """

    def __init__(self,
                 context_list_path: str,
                 symbol_table: Dict[str, int],
                 bpe_model: str = None,
                 context_score: float = 6.0):
        """Initialize a ContextGraph with the given ``context_score``.

        A root node will be created (**NOTE:** the token of the root is
        hardcoded to -1).

        Args:
          context_score:
            The bonus score for each token (note: NOT for each word/phrase),
            which means a longer word/phrase gets a larger bonus, but it has
            to be fully matched.
        """
        self.context_score = context_score
        self.context_list = tokenize(context_list_path, symbol_table,
                                     bpe_model)
        self.num_nodes = 0
        self.root = ContextState(
            id=self.num_nodes,
            token=-1,
            token_score=0,
            node_score=0,
            output_score=0,
            is_end=False,
        )
        self.root.fail = self.root
        self.build_graph(self.context_list)

    def build_graph(self, token_ids: List[List[int]]):
        """Build the ContextGraph from a list of token lists.

        It first builds a trie from the given token lists, then fills in the
        fail arc for each trie node.

        See https://en.wikipedia.org/wiki/Trie for how to build a trie.

        Args:
          token_ids:
            The token lists used to build the ContextGraph. It is a list of
            token lists; each token list contains the token ids of one
            word/phrase. A token id could be the id of a char (when modeling
            with single Chinese chars) or the id of a BPE unit (when modeling
            with BPEs).
        """
        for tokens in token_ids:
            node = self.root
            for i, token in enumerate(tokens):
                if token not in node.next:
                    self.num_nodes += 1
                    is_end = i == len(tokens) - 1
                    node_score = node.node_score + self.context_score
                    node.next[token] = ContextState(
                        id=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node_score,
                        output_score=node_score if is_end else 0,
                        is_end=is_end,
                    )
                node = node.next[token]
        self._fill_fail_output()  # AC
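
    # An illustrative sketch, not part of the original file: with
    # context_score=6.0 and token_ids=[[1, 2], [2, 3]], build_graph() creates
    # the paths root->1->2 and root->2->3. The depth-1 nodes get
    # node_score=6.0 and the depth-2 nodes node_score=12.0; the two depth-2
    # nodes are end nodes, so their output_score equals their node_score.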

    def _fill_fail_output(self):
        """Fill in the fail arc for each trie node. This can be computed in
        linear time by performing a breadth-first search starting from the
        root.
        See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for
        the details of the algorithm.
        """
        queue = deque()
        for token, node in self.root.next.items():
            node.fail = self.root
            queue.append(node)
        while queue:
            current_node = queue.popleft()
            for token, node in current_node.next.items():
                fail = current_node.fail
                if token in fail.next:
                    fail = fail.next[token]
                else:
                    fail = fail.fail
                    while token not in fail.next:
                        fail = fail.fail
                        if fail.token == -1:  # root
                            break
                    if token in fail.next:
                        fail = fail.next[token]
                node.fail = fail
                # fill the output arc
                output = node.fail
                while not output.is_end:
                    output = output.fail
                    if output.token == -1:  # root
                        output = None
                        break
                node.output = output
                node.output_score += 0 if output is None else output.output_score
                queue.append(node)
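
    # An illustrative sketch, not part of the original file: continuing the
    # toy trie above ([[1, 2], [2, 3]]), the BFS sets the fail arc of node
    # root->1->2 to node root->2, its longest proper suffix present in the
    # trie. Since root->2 is not an end node, the output arc of root->1->2
    # stays None and its output_score is left unchanged.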

    def forward_one_step(self, state: ContextState,
                         token: int) -> Tuple[float, ContextState]:
        """Search the graph with the given state and token.

        Args:
          state:
            The trie node to start from.
          token:
            The given token.

        Returns:
          Return a tuple of (score, next state).
        """
        node = None
        score = 0
        # token matched
        if token in state.next:
            node = state.next[token]
            score = node.token_score
        else:
            # token not matched
            # We trace along the fail arcs until the token matches or the
            # root of the graph is reached.
            node = state.fail
            while token not in node.next:
                node = node.fail
                if node.token == -1:  # root
                    break
            if token in node.next:
                node = node.next[token]
            # The score of the fail path
            score = node.node_score - state.node_score
        assert node is not None
        return (score + node.output_score, node)
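
    # An illustrative sketch, not part of the original file: with the toy trie
    # above and context_score=6.0, stepping from the root with token 1 returns
    # (6.0, node "1"); stepping again with token 3 finds no match, follows the
    # fail arcs back to the root and returns (-6.0, root), which cancels the
    # bonus of the abandoned partial match.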

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """When reaching the end of the decoded sequence, we need to finalize
        the matching: the purpose is to subtract the bonus that was added for
        a partially matched word/phrase.

        Args:
          state:
            The given state (trie node).

        Returns:
          Return a tuple of (score, next state). The score is the negative of
          the accumulated node_score of the given state, i.e. the score of an
          implicit fail arc back to the root. The next state is always the
          root.
        """
        # The score of the fail arc
        score = -state.node_score
        return (score, self.root)
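

# A minimal usage sketch, not part of the original file. The temporary
# biasing-list file and the toy char-level symbol table below are made up for
# illustration; real decoding code would pass its own biasing list and the
# model's unit dictionary, calling forward_one_step() once per emitted token
# and finalize() at the end of the hypothesis.
if __name__ == "__main__":
    import os
    import tempfile

    symbol_table = {"<unk>": 0, "h": 1, "i": 2}
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("hi\n")
        context_list_path = f.name

    graph = ContextGraph(context_list_path, symbol_table, context_score=6.0)
    os.remove(context_list_path)

    state = graph.root
    total = 0.0
    for token in [symbol_table["h"], symbol_table["i"]]:
        score, state = graph.forward_one_step(state, token)
        total += score
    score, state = graph.finalize(state)
    total += score
    # "hi" is fully matched, so the net bonus is 2 * context_score = 12.0
    # (the extra output_score collected at the end node is cancelled by
    # finalize()'s -node_score).
    print(total)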