File size: 1,395 Bytes
539e83f ec156ad 539e83f 1ef8679 539e83f ec156ad 539e83f ec156ad 539e83f 1ef8679 539e83f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import argparse
from idiomify import tensors as T
from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
from transformers import BertTokenizer
from termcolor import colored
def main():
    """Score every idiom against one query sentence and print the ranking.

    Command-line options (all strings):
        --model  first-level key into the mapping returned by
                 ``fetch_config()`` (default ``"alpha"``).
        --ver    second-level key into that mapping (default ``"eng2eng"``).
        --sent   the query sentence (default
                 ``"to avoid getting to the point"``).

    The parsed options are merged into the selected config entry, so the
    ``'model'``, ``'ver'`` and ``'sent'`` keys always come from the command
    line. The remaining keys read here (``'idioms_ver'``, ``'bert'``,
    ``'k'``) come from the config entry itself.

    Output: a ``query: ...`` line with the sentence coloured blue, then one
    line per idiom as ``<rank> <idiom> <score>``, highest score first.
    """
    parser = argparse.ArgumentParser()
    option_defaults = (
        ("--model", "alpha"),
        ("--ver", "eng2eng"),
        ("--sent", "to avoid getting to the point"),
    )
    for flag, fallback in option_defaults:
        parser.add_argument(flag, type=str, default=fallback)
    args = parser.parse_args()

    # Pick the config entry for this model/version, then let the CLI values
    # override (or add) the matching keys.
    config = fetch_config()[args.model][args.ver]
    config.update(vars(args))

    idioms = fetch_idioms(config['idioms_ver'])
    rd = fetch_rd(config['model'], config['ver'])
    rd.eval()
    tokenizer = BertTokenizer.from_pretrained(config['bert'])

    # A single (empty string, sentence) pair is passed as the only input.
    # NOTE(review): the meaning of the empty first element is defined by
    # T.inputs -- confirm there.
    X = T.inputs([("", config['sent'])], tokenizer, config['k'])
    probs = rd.P_wisdom(X).squeeze().tolist()

    # Assumes P_wisdom yields one score per idiom, in the same order as
    # `idioms` -- TODO confirm. Pair them up and order by score, descending.
    ranking = sorted(zip(idioms, probs), key=lambda pair: pair[1], reverse=True)

    highlighted = colored(text=config['sent'], color='blue')
    print(f"query: {highlighted}")
    for position, pair in enumerate(ranking):
        print(position, *pair)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|