File size: 1,457 Bytes
e9d1a5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
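# The code below is disabled for now (every line, including the imports, is commented out).
# NOTE(review): the body calls `T.inputs(...)`, but `T` is never imported here —
# add the missing import before re-enabling, or it will fail with a NameError.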
# import argparse
# from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
# from transformers import BertTokenizer
# from termcolor import colored
#
#
# def main():
# parser = argparse.ArgumentParser()
# parser.add_argument("--model", type=str,
# default="alpha")
# parser.add_argument("--ver", type=str,
# default="eng2eng")
# parser.add_argument("--sent", type=str,
# default="to avoid getting to the point")
# args = parser.parse_args()
# config = fetch_config()[args.model][args.ver]
# config.update(vars(args))
# idioms = fetch_idioms(config['idioms_ver'])
# rd = fetch_rd(config['model'], config['ver'])
# rd.eval()
# tokenizer = BertTokenizer.from_pretrained(config['bert'])
# X = T.inputs([("", config['sent'])], tokenizer, config['k'])
# probs = rd.P_wisdom(X).squeeze().tolist()
# wisdom2prob = [
# (wisdom, prob)
# for wisdom, prob in zip(idioms, probs)
# ]
# # sort and append
# res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
# print(f"query: {colored(text=config['sent'], color='blue')}")
# for idx, (idiom, prob) in enumerate(res):
# print(idx, idiom, prob)
#
#
# if __name__ == '__main__':
# main()
|