File size: 1,457 Bytes
e9d1a5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
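# This script is disabled for now; its original implementation is kept below, commented out.
# NOTE(review): main() calls `T.inputs(...)`, but `T` is not among the imports listed below — add that import before re-enabling.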
# import argparse
# from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
# from transformers import BertTokenizer
# from termcolor import colored
#
#
# def main():
#         parser = argparse.ArgumentParser()
#         parser.add_argument("--model", type=str,
#                             default="alpha")
#         parser.add_argument("--ver", type=str,
#                             default="eng2eng")
#         parser.add_argument("--sent", type=str,
#                             default="to avoid getting to the point")
#         args = parser.parse_args()
#         config = fetch_config()[args.model][args.ver]
#         config.update(vars(args))
#         idioms = fetch_idioms(config['idioms_ver'])
#         rd = fetch_rd(config['model'], config['ver'])
#         rd.eval()
#         tokenizer = BertTokenizer.from_pretrained(config['bert'])
#         X = T.inputs([("", config['sent'])], tokenizer, config['k'])
#         probs = rd.P_wisdom(X).squeeze().tolist()
#         wisdom2prob = [
#                 (wisdom, prob)
#                 for wisdom, prob in zip(idioms, probs)
#         ]
#         # sort and append
#         res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
#         print(f"query: {colored(text=config['sent'], color='blue')}")
#         for idx, (idiom, prob) in enumerate(res):
#             print(idx, idiom, prob)
#
#
# if __name__ == '__main__':
#     main()