File size: 3,508 Bytes
3d9b66a
 
eb4b465
3d9b66a
eb4b465
 
3d9b66a
 
 
 
 
 
 
8e9c357
eb4b465
3d9b66a
 
 
 
 
 
 
 
 
 
 
 
 
 
2636ace
3d9b66a
 
 
 
eb4b465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d9b66a
8e9c357
3d9b66a
eb4b465
3d9b66a
 
 
2636ace
 
8e9c357
eb4b465
3d9b66a
8e9c357
 
 
 
 
 
 
d4533bc
8e9c357
 
 
 
 
d4533bc
8e9c357
d4533bc
3d9b66a
 
 
 
 
 
 
 
 
8e9c357
 
 
 
 
 
3d9b66a
 
d4533bc
 
3d9b66a
2636ace
3d9b66a
eb4b465
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F

# Prompt template fed to the code model. code_from_prompts substitutes its
# placeholders before generation:
#   CONN   -> the psycopg2 connection string
#   MIDDLE -> an optional pre-written get_customer() function, or the whole
#             line is removed
#   PROMPT -> the user's comment text
# The template stops mid-statement right after `SET name`, presumably so the
# model's next tokens show whether it finishes the UPDATE by string
# concatenation (the UI description and the scored '= " +' continuation
# suggest this).
# NOTE(review): this template uses table `customer` and 2-space indentation,
# while the pre-written get_customer() helpers use `customers` and 4 spaces;
# confirm whether that mismatch is intentional for the experiment.
header = """
import psycopg2

conn = psycopg2.connect("CONN")
cur = conn.cursor()

MIDDLE
def rename_customer(id, newName):
  # PROMPT
  cur.execute("UPDATE customer SET name
"""

# Maps each label shown in the "Code Model" radio button to the checkpoint id
# handed to AutoTokenizer.from_pretrained / ecco.from_pretrained.
# Entries are listed in display order; dict() preserves insertion order.
modelPath = dict([
    ("GPT2-Medium", "gpt2-medium"),
    ("CodeParrot-mini", "codeparrot/codeparrot-small"),
    ("CodeGen-350-Mono", "Salesforce/codegen-350M-mono"),
    ("GPT-J", "EleutherAI/gpt-j-6B"),
    ("CodeParrot", "codeparrot/codeparrot"),
    ("CodeGen-2B-Mono", "Salesforce/codegen-2B-mono"),
])

def generation(tokenizer, model, content):
    """Complete *content* without sampling and score a target continuation.

    Args:
        tokenizer: Tokenizer matching *model*; used to encode the target
            continuation and to count the tokens of each scored prompt.
        model: ecco language model (as returned by ``ecco.from_pretrained``)
            whose ``generate`` is called directly on text.
        content: Prompt text to complete.

    Returns:
        ``[completion, probability]``: ``completion`` is ecco's token list for
        a 10-token, non-sampled generation joined into one string;
        ``probability`` is a float, the product of the probabilities the
        model assigns to each target token given *content* plus the target
        tokens before it (1.0 if there are no target tokens).
    """
    # NOTE(review): ``[1:]`` drops the first id of the encoded string; it is
    # unclear whether that is meant to skip a leading special token or the
    # '=' token for every tokenizer in modelPath -- confirm.
    seek_token_ids = tokenizer.encode('= " +')[1:]

    full_output = model.generate(content, generate=10, do_sample=False)

    def token_probability(prompt, token_id):
        """Return the model's probability that *token_id* follows *prompt*."""
        # Index of the prompt's last token. Recomputed from the text itself
        # (the same way the initial prompt is measured) because re-tokenising
        # prompt + decoded token does not always add exactly one token.
        position = len(tokenizer(prompt)['input_ids'])
        step = model.generate(prompt, generate=1, do_sample=False)
        # Assumes hidden_states[-1] is the final layer indexed by position
        # first, i.e. (seq, hidden) -- TODO confirm for the installed ecco.
        hidden_state = step.hidden_states[-1][position - 1]
        logits = step.lm_head(step.to(hidden_state))
        return F.softmax(logits, dim=-1)[token_id]

    prob = 1.0
    prompt = content
    for token_id in seek_token_ids:
        prob = prob * token_probability(prompt, token_id)
        # Condition the next target token on the ones already scored.
        prompt += tokenizer.decode(token_id)

    # float() so the UI shows a plain number instead of a tensor repr.
    return ["".join(full_output.tokens), float(prob)]

# Pre-written get_customer() helpers that can be spliced into the prompt at
# the MIDDLE marker. One builds its query by string concatenation; the other
# passes the id as a bound query parameter.
_GET_CUSTOMER_CONCAT = (
    "def get_customer(id):\n"
    "    cur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n"
    "    return cur.fetchall()\n"
)
_GET_CUSTOMER_PARAMETERIZED = (
    "def get_customer(id):\n"
    # psycopg2 expects query parameters as a sequence; a bare str(id) would be
    # treated as one parameter per character.
    "    cur.execute('SELECT * FROM customers WHERE id = %s', (id,))\n"
    "    return cur.fetchall()\n"
)


def _build_prompt(prompt, type_hints, pre_content):
    """Fill the module-level ``header`` template from the UI choices.

    Args:
        prompt: User comment placed at the PROMPT marker (None becomes "").
        type_hints: When truthy, annotate the function signatures.
        pre_content: Selected "Pre-written function" label; picks which
            get_customer() helper, if any, replaces the MIDDLE marker.

    Returns:
        The prompt text to feed to the code model.
    """
    code = header.strip()

    if pre_content and 'Concatenation' in pre_content:
        code = code.replace('MIDDLE', _GET_CUSTOMER_CONCAT)
    elif pre_content and 'composition' in pre_content:
        code = code.replace('MIDDLE', _GET_CUSTOMER_PARAMETERIZED)
    else:
        # "None", no selection (Gradio passes None), or an unexpected label:
        # drop the marker line so it never reaches the model.
        code = code.replace('MIDDLE\n', '')

    if type_hints:
        # Rewrite the exact signatures. Blanket replacements such as
        # 'id)' -> 'id: int)' would also hit call sites like ``str(id))``.
        code = code.replace('def rename_customer(id, newName):',
                            'def rename_customer(id: int, newName: str) -> None:')
        code = code.replace('def get_customer(id):', 'def get_customer(id: int):')

    # Insert the user's text last, so markers or signature-like text inside
    # it are left untouched.
    return code.replace('CONN', "dbname='store'").replace('PROMPT', prompt or '')


def code_from_prompts(prompt, model, type_hints, pre_content):
    """Build the prompt from the UI choices and run it through the chosen model.

    Args:
        prompt: Comment text typed by the user.
        model: Display name of the model; a key of ``modelPath``.
        type_hints: Whether to add type annotations to the function signatures.
        pre_content: Label of the selected "Pre-written function" option.

    Returns:
        The ``[completion, probability]`` list produced by ``generation``.
    """
    checkpoint = modelPath[model]
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    lm = ecco.from_pretrained(checkpoint)
    code = _build_prompt(prompt, type_hints, pre_content)
    # tokenizer and lm are locals, so their references are dropped on return.
    return generation(tokenizer, lm, code)

# Gradio UI. The inputs map positionally onto code_from_prompts'
# (prompt, model, type_hints, pre_content) parameters; the two outputs
# receive the two elements of the list it returns.
_pre_written_choices = [
    "None",
    "Proper composition: Include function 'WHERE id = %s'",
    "Concatenation: Include a function with 'WHERE id = ' + id",
]
_ui_inputs = [
    gr.inputs.Textbox(label="Insert comment"),
    gr.inputs.Radio(list(modelPath), label="Code Model"),
    gr.inputs.Checkbox(label="Include type hints"),
    gr.inputs.Radio(_pre_written_choices, label="Pre-written function"),
]
_ui_outputs = [
    gr.outputs.Textbox(label="Most probable code"),
    gr.outputs.Textbox(label="Probability of concat"),
]

iface = gr.Interface(
    fn=code_from_prompts,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()