|
import argparse |
|
from idiomify import tensors as T |
|
from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms |
|
from transformers import BertTokenizer |
|
from termcolor import colored |
|
|
|
def main():
    """Score every idiom against a query sentence and print the ranking.

    Command-line flags (all strings):
        --model: first-level key into the fetched config (default "alpha").
        --ver:   second-level key into the fetched config
                 (default "eng2eng").
        --sent:  the query sentence
                 (default "to avoid getting to the point").

    The selected config entry is updated with the parsed flags, so the
    command-line values take precedence over same-named config keys.
    Output goes to stdout: the query (coloured blue) followed by one
    ``index idiom probability`` line per idiom, highest probability first.
    """
    cli = argparse.ArgumentParser()
    # All three options are plain string flags that differ only in their
    # name and default, so register them from a table.
    for flag, fallback in (
        ("--model", "alpha"),
        ("--ver", "eng2eng"),
        ("--sent", "to avoid getting to the point"),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    args = cli.parse_args()

    # Pick the config entry for this model/version, then overlay CLI values.
    config = fetch_config()[args.model][args.ver]
    config.update(vars(args))

    idioms = fetch_idioms(config['idioms_ver'])
    rd = fetch_rd(config['model'], config['ver'])
    rd.eval()  # presumably switches the model to inference mode -- confirm
    tokenizer = BertTokenizer.from_pretrained(config['bert'])

    # A single (empty string, sentence) pair is encoded as the model input.
    X = T.inputs([("", config['sent'])], tokenizer, config['k'])
    probs = rd.P_wisdom(X).squeeze().tolist()

    # Pair each idiom with its probability and order by probability,
    # highest first; sorted() is stable, so ties keep the idiom order.
    ranked = sorted(
        zip(idioms, probs),
        key=lambda pair: pair[1],
        reverse=True,
    )

    print(f"query: {colored(text=config['sent'], color='blue')}")
    for idx, (idiom, prob) in enumerate(ranked):
        print(idx, idiom, prob)
|
|
|
|
|
if __name__ == '__main__':
    # Run the command-line entry point only when this file is executed as
    # a script, not when it is imported as a module.
    main()
|
|