# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from collections import defaultdict
class TreeNode:
    """One node of a trie.

    ``child`` maps an item to the ``TreeNode`` reached by following that
    item from this node.
    """

    def __init__(self):
        # A defaultdict keyed by item: indexing a missing key builds and
        # stores a fresh TreeNode, while ``.get`` leaves the mapping as is.
        self.child = defaultdict(TreeNode)
class Trie:
    """Prefix tree over sequences of hashable items."""

    def __init__(self, eos):
        """Create an empty trie.

        Args:
            eos: Value that ``get_next_layer`` returns (wrapped in a list)
                when the queried prefix is not present in the trie.
        """
        self.eos = eos
        self.root = TreeNode()

    def insert(self, word):
        """Add ``word`` to the trie, creating nodes along its path as needed.

        Args:
            word: Iterable of hashable items, consumed in order.
        """
        node = self.root
        for item in word:
            # Indexing the children creates the child node if it is missing.
            children = node.child
            node = children[item]

    def get_next_layer(self, word):
        """Return the items that can directly follow the prefix ``word``.

        Args:
            word: Iterable of items forming the prefix to look up.

        Returns:
            A list of the child keys of the node reached by ``word`` (empty
            when that node has no children), or ``[self.eos]`` when the
            prefix is not present in the trie.
        """
        node = self.root
        for item in word:
            # ``.get`` is used so a failed lookup never inserts a new node.
            node = node.child.get(item)
            if node is None:
                break
        else:
            # Every item matched: report the keys one level below.
            return list(node.child)
        return [self.eos]