saujasv commited on
Commit
74f6684
·
1 Parent(s): 0dbe51e

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from greenery import parse
3
+ from greenery.parse import NoMatch
4
+ from listener import Listener, ListenerOutput
5
+ import time
6
+ import json
7
+ import torch
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ self.listener = Listener(path, {
12
+ "do_sample": True,
13
+ "max_new_tokens": 128,
14
+ "top_p": 0.9,
15
+ "num_return_sequences": 500,
16
+ "num_beams": 1
17
+ }, device="cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ def __call__(self, data):
20
+ # get inputs
21
+ inp = data.pop("inputs", None)
22
+ spec = inp["spec"]
23
+ true_program = inp["true_program"]
24
+
25
+ start = time.time()
26
+ outputs = self.listener.synthesize([[(s["string"], s["label"]) for s in spec]], return_scores=True)
27
+ consistent_program_scores = [outputs.decoded_scores[0][i] for i in outputs.idx[0]]
28
+ consistent_programs = [outputs.decoded[0][i] for i in outputs.idx[0]]
29
+ sorted_programs = sorted(set(zip(consistent_program_scores, consistent_programs)), reverse=True, key=lambda x: x[0])
30
+ end = time.time()
31
+
32
+ return {
33
+ "guess": sorted_programs[0][1],
34
+ "top_1_success": parse(sorted_programs[0][1]).equivalent(parse(true_program)),
35
+ "top_1_score": sorted_programs[0][0],
36
+ "top_5_success": any([parse(p).equivalent(parse(true_program)) for _, p in sorted_programs[:5]]),
37
+ "time": end - start
38
+ }