File size: 2,952 Bytes
87c604f
2869f1d
678feb3
eb939c7
2869f1d
 
eb939c7
 
678feb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ce2c0
678feb3
 
 
 
 
6b9bc55
2869f1d
678feb3
2869f1d
678feb3
 
 
 
 
 
 
 
966ca93
678feb3
 
 
 
 
 
 
 
 
87c604f
2869f1d
678feb3
 
87c604f
678feb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c604f
678feb3
87c604f
 
678feb3
 
 
 
 
 
 
 
 
 
 
 
87c604f
 
678feb3
 
 
 
 
 
 
 
87c604f
678feb3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Any
import gradio as gr
import itertools
import torch
from listener import Listener

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint loaded by the listener. A second checkpoint,
# "pragmatic-programs/pragmatic-ft-listener", is currently disabled; it was
# configured with the same generation settings as the ones below.
LISTENER_CHECKPOINT = "pragmatic-programs/listener-suffix-idx-300k"

# Settings handed to Listener as its second positional argument.
# NOTE(review): these look like Hugging Face `generate` keyword arguments
# (100 sampled candidates, top-p 0.9) -- confirm against listener.py.
GENERATION_SETTINGS = {
    "do_sample": True,
    "num_return_sequences": 100,
    "num_beams": 1,
    "temperature": 1,
    "top_p": 0.9,
    "max_new_tokens": 128,
}

listener = Listener(LISTENER_CHECKPOINT, GENERATION_SETTINGS)

# Number of (example string, label) rows shown in the UI.
N_EXAMPLES = 3


def synthesize(*inps):
    """Synthesize a regular expression from the labelled examples in the UI.

    Args:
        *inps: Flattened ``(text, label)`` pairs, one pair per example row:
            ``inps[2 * i]`` is the example string and ``inps[2 * i + 1]`` is
            the radio choice for that row (``"match"``, ``"not match"``, or
            ``None`` when nothing is selected). Any number of pairs is
            accepted.

    Returns:
        A single string for the output textbox: the first program returned by
        the listener, or one of ``"Empty specification"``,
        ``"Invalid specification"`` and ``"No program found"``.
    """
    examples = []
    # Walk the rows pairwise so a string can only ever be paired with the
    # label from its own row (filtering strings and labels independently
    # would shift labels onto the wrong strings).
    for text, choice in zip(inps[0::2], inps[1::2]):
        if not text:
            # Rows without text are ignored, whatever the radio button says.
            continue
        if choice is None:
            # An example string without a label cannot be used.
            return "Invalid specification"
        examples.append((text, "+" if choice == "match" else "-"))

    if not examples:
        return "Empty specification"

    # The listener is given a list of specifications; this UI builds exactly one.
    spec = [examples]
    outputs = listener.synthesize(spec).programs

    # outputs[0] is indexed as the candidate list for the only specification.
    if len(outputs) > 0 and len(outputs[0]) > 0:
        return outputs[0][0]
    return "No program found"


def _example_row(index):
    """Build the textbox and match/not-match radio for example row *index*."""
    text_box = gr.Textbox(
        lines=1,
        label=f"Example {index + 1}",
        container=True,
    )
    label_choice = gr.Radio(
        ["match", "not match"],
        container=False,
        label="Label",
    )
    return text_box, label_choice


# One (Textbox, Radio) tuple per example row.
input_fields = [_example_row(row) for row in range(N_EXAMPLES)]

# Flatten the (Textbox, Radio) rows into text1, label1, text2, label2, ...,
# the positional order in which `synthesize` receives its arguments.
flat_inputs = list(itertools.chain(*input_fields))
result_box = gr.Textbox(lines=1, label="Synthesizer output")
soft_blue_theme = gr.themes.Soft(primary_hue="blue")

iface = gr.Interface(
    fn=synthesize,
    inputs=flat_inputs,
    outputs=result_box,
    title="Synthesizing regular expressions from examples",
    theme=soft_blue_theme,
)
iface.launch()