File size: 1,395 Bytes
539e83f
 
ec156ad
539e83f
1ef8679
539e83f
 
 
 
 
 
 
 
ec156ad
539e83f
 
 
 
 
ec156ad
 
 
539e83f
 
 
 
 
 
 
1ef8679
539e83f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
from idiomify import tensors as T
from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
from transformers import BertTokenizer
from termcolor import colored

def main():
    """Score every idiom against a query sentence and print the ranking.

    Command-line options:
        --model: first-level key into the mapping returned by
            ``fetch_config()`` (default ``"alpha"``).
        --ver: second-level key into that mapping (default ``"eng2eng"``).
        --sent: the query sentence (default
            ``"to avoid getting to the point"``).

    Output (stdout): the query sentence coloured blue, followed by one line
    per idiom -- rank index, idiom, score -- ordered from the highest score
    to the lowest.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str,
                        default="alpha")
    parser.add_argument("--ver", type=str,
                        default="eng2eng")
    parser.add_argument("--sent", type=str,
                        default="to avoid getting to the point")
    args = parser.parse_args()

    config = fetch_config()[args.model][args.ver]
    # Merge the CLI values into the stored config so that 'model', 'ver'
    # and 'sent' can be read from the same dict as the other settings.
    config.update(vars(args))

    idioms = fetch_idioms(config['idioms_ver'])
    rd = fetch_rd(config['model'], config['ver'])
    # NOTE(review): presumably switches the model to inference mode
    # (torch-style ``eval()``) -- confirm against fetch_rd's return type.
    rd.eval()
    tokenizer = BertTokenizer.from_pretrained(config['bert'])

    # A single (empty string, sentence) pair is passed as the whole batch.
    X = T.inputs([("", config['sent'])], tokenizer, config['k'])
    # NOTE(review): assumes P_wisdom yields one score per idiom, in the same
    # order as `idioms`, once squeezed -- confirm.
    probs = rd.P_wisdom(X).squeeze().tolist()

    # Highest score first; `sorted` is stable, so ties keep the idiom order.
    ranked = sorted(zip(idioms, probs), key=lambda pair: pair[1],
                    reverse=True)
    print(f"query: {colored(text=config['sent'], color='blue')}")
    for idx, (idiom, prob) in enumerate(ranked):
        print(idx, idiom, prob)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()