File size: 2,950 Bytes
87c604f
2869f1d
678feb3
eb939c7
2869f1d
 
eb939c7
 
678feb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb939c7
678feb3
 
 
 
 
6b9bc55
2869f1d
678feb3
2869f1d
678feb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c604f
2869f1d
678feb3
 
87c604f
678feb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c604f
678feb3
87c604f
 
678feb3
 
 
 
 
 
 
 
 
 
 
 
87c604f
 
678feb3
 
 
 
 
 
 
 
87c604f
678feb3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Any
import gradio as gr
import itertools
import torch
from listener import Listener

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE: a second checkpoint, "pragmatic-programs/pragmatic-ft-listener"
# (same settings but 100 return sequences), used to be wired up here as well.

# Model queried by `synthesize` below. The options dict configures sampling:
# 500 returned sequences, a single beam, temperature 1, top-p 0.9 and at most
# 128 new tokens. These are presumably forwarded to the model's generation
# call inside `Listener` — confirm against listener.py.
listener = Listener(
    "pragmatic-programs/listener-suffix-idx-300k",
    dict(
        do_sample=True,
        num_return_sequences=500,
        num_beams=1,
        temperature=1,
        top_p=0.9,
        max_new_tokens=128,
    ),
)

# Number of (example string, label) rows rendered in the UI.
N_EXAMPLES = 3


def synthesize(*inps):
    """Synthesize a regular expression from the labelled examples in the UI.

    Gradio passes the input components' values positionally, interleaved as
    (example string, label) for each example row. A label is "match",
    "not match", or None when no radio button has been selected.

    Rows where both the string and the label are empty are ignored. A row
    with only one of the two filled in makes the whole specification invalid,
    instead of its value being silently paired with another row's value.

    Returns:
        A single string for the single output textbox. It is the top program
        returned by the listener, or one of "Empty specification",
        "Invalid specification" or "No program found".
    """
    examples = []
    # Pair each row's text with that same row's label before any filtering,
    # so a half-filled row can never shift the pairing of later rows.
    for text, choice in zip(inps[0::2], inps[1::2]):
        text = text or ""  # tolerate None from the textbox
        if not text and choice is None:
            continue  # untouched row
        if not text or choice not in ("match", "not match"):
            return "Invalid specification"
        examples.append((text, "+" if choice == "match" else "-"))

    if not examples:
        return "Empty specification"

    # The listener appears to take a batch of specifications and to return
    # one list of candidate programs per specification (hence the [0]
    # indexing below). Submit a batch of one and report its top program.
    outputs = listener.synthesize([examples]).programs
    programs = outputs[0]
    if len(programs) > 0:
        return programs[0]
    return "No program found"


# One (textbox, radio) pair per example row. `synthesize` receives the
# values of these components interleaved in exactly this order.
input_fields = [
    (
        gr.Textbox(label=f"Example {row + 1}", lines=1, container=True),
        gr.Radio(["match", "not match"], label="Label", container=False),
    )
    for row in range(N_EXAMPLES)
]

iface = gr.Interface(
    fn=synthesize,
    # Flatten to textbox_1, radio_1, textbox_2, radio_2, ...
    inputs=[component for pair in input_fields for component in pair],
    outputs=gr.Textbox(label="Synthesizer output", lines=1),
    title="Synthesizing regular expressions from examples",
    theme=gr.themes.Soft(primary_hue="blue"),
)
iface.launch()