# pragmatic-ft-listener / handler.py
# Author: saujasv — commit 74f6684 "Create handler.py" (1.48 kB)
from transformers import pipeline
from greenery import parse
from greenery.parse import NoMatch
from listener import Listener, ListenerOutput
import time
import json
import torch
class EndpointHandler:
    """Custom Inference Endpoint handler wrapping a project-local ``Listener``.

    For each request the listener samples candidate programs for a spec of
    labelled example strings. Candidates whose indices appear in
    ``outputs.idx`` are ranked by score, and the best ones are compared
    against a reference program with ``greenery`` (which parses the programs
    as regular expressions).
    """

    # Keyword arguments handed to ``Listener``; presumably forwarded to
    # transformers' ``generate`` (TODO confirm in listener.py). Callers may
    # override individual entries via ``generation_kwargs``.
    DEFAULT_GENERATION_KWARGS = {
        "do_sample": True,
        "max_new_tokens": 128,
        "top_p": 0.9,
        "num_return_sequences": 500,
        "num_beams": 1,
    }

    def __init__(self, path="", generation_kwargs=None):
        """Load the listener from ``path``.

        Args:
            path: Model location passed unchanged to ``Listener``.
            generation_kwargs: Optional mapping merged over
                ``DEFAULT_GENERATION_KWARGS``; ``None`` keeps the defaults.
        """
        config = dict(self.DEFAULT_GENERATION_KWARGS)  # copy: never mutate the class default
        if generation_kwargs:
            config.update(generation_kwargs)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.listener = Listener(path, config, device=device)

    def __call__(self, data):
        """Synthesize programs for one spec and check them against a reference.

        Args:
            data: Request payload. The request body is read from
                ``data["inputs"]`` (popped, so ``data`` is mutated), or from
                ``data`` itself when there is no ``"inputs"`` key. It must hold:
                  * ``"spec"``: list of dicts with ``"string"`` and ``"label"`` keys.
                  * ``"true_program"``: the reference program string.

        Returns:
            dict with keys:
              * ``guess``: highest-scoring candidate, or ``None`` if there is none.
              * ``top_1_success``: whether ``guess`` is equivalent to the reference.
              * ``top_1_score``: score of ``guess`` as a float, or ``None``.
              * ``top_5_success``: whether any of the 5 best distinct candidates
                is equivalent to the reference.
              * ``time``: seconds spent in synthesis and ranking.

        Raises:
            KeyError: if ``"spec"`` or ``"true_program"`` is missing.
        """
        # Endpoint payloads are conventionally wrapped in "inputs"; fall back
        # to the raw dict so an unwrapped payload also works.
        inp = data.pop("inputs", data)
        spec = inp["spec"]
        true_program = inp["true_program"]

        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for an elapsed duration.
        start = time.perf_counter()
        examples = [(s["string"], s["label"]) for s in spec]
        outputs = self.listener.synthesize([examples], return_scores=True)
        ranked = self._rank_consistent(outputs)
        elapsed = time.perf_counter() - start

        if not ranked:
            # No candidate survived: report a miss instead of raising IndexError.
            return {
                "guess": None,
                "top_1_success": False,
                "top_1_score": None,
                "top_5_success": False,
                "time": elapsed,
            }

        target = parse(true_program)  # parse the reference once, reuse below
        top_score, guess = ranked[0]
        top_1_success = self._matches(guess, target)
        top_5_success = top_1_success or any(
            self._matches(program, target) for _, program in ranked[1:5]
        )
        return {
            "guess": guess,
            "top_1_success": top_1_success,
            "top_1_score": top_score,
            "top_5_success": top_5_success,
            "time": elapsed,
        }

    @staticmethod
    def _rank_consistent(outputs):
        """Return ``(score, program)`` pairs for batch element 0, best first.

        Only candidates whose index appears in ``outputs.idx[0]`` are kept.
        Each distinct program appears once, with its highest score, so one
        program cannot occupy several top-k slots. Ties are broken by program
        text to make the ranking deterministic.
        """
        best = {}
        for i in outputs.idx[0]:
            program = outputs.decoded[0][i]
            # float() normalises tensor/NumPy scalars so the response is
            # JSON-serialisable; it is a no-op for Python floats.
            score = float(outputs.decoded_scores[0][i])
            if program not in best or score > best[program]:
                best[program] = score
        return sorted(
            ((score, program) for program, score in best.items()),
            key=lambda pair: (-pair[0], pair[1]),
        )

    @staticmethod
    def _matches(program, target):
        """Return whether ``program`` parses and is equivalent to ``target``.

        Model-generated programs may be malformed. greenery signals that with
        ``NoMatch``, which is treated as a non-match rather than failing the
        whole request.
        """
        try:
            candidate = parse(program)
        except NoMatch:
            return False
        return candidate.equivalent(target)