Spaces:
Runtime error
Runtime error
File size: 3,677 Bytes
3d9b66a eb4b465 3d9b66a eb4b465 3d9b66a 8e9c357 dac8d18 3d9b66a e1e1d6a 3d9b66a 2636ace 3d9b66a ac62d88 eb4b465 e1e1d6a eb4b465 9b11bf4 eb4b465 ac62d88 e01d0a4 3d9b66a 8e9c357 3d9b66a eb4b465 3d9b66a 2636ace 8e9c357 eb4b465 3d9b66a 8e9c357 e01d0a4 d4533bc 8e9c357 e01d0a4 d4533bc 8e9c357 d4533bc 3d9b66a 8e9c357 3d9b66a d4533bc 3d9b66a 2636ace 3d9b66a eb4b465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F
header = """
import psycopg2
conn = psycopg2.connect("CONN")
cur = conn.cursor()
MIDDLE
def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
"""
modelPath = {
"GPT2-Medium": "gpt2-medium",
"CodeParrot-mini": "codeparrot/codeparrot-small",
"CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
"GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
"CodeParrot": "codeparrot/codeparrot",
"CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}
def generation(tokenizer, model, content):
decoder = 'Standard'
num_beams = 2 if decoder == 'Beam' else None
typical_p = 0.8 if decoder == 'Typical' else None
do_sample = (decoder in ['Beam', 'Typical', 'Sample'])
seek_token_ids = [
tokenizer.encode('= \'" +')[1:],
tokenizer.encode('= " +')[1:],
]
full_output = model.generate(content, generate=6, do_sample=False)
def next_words(code, position, seek_token_ids):
op_model = model.generate(code, generate=1, do_sample=False)
hidden_states = op_model.hidden_states
layer_no = len(hidden_states) - 1
h = hidden_states[-1]
hidden_state = h[position - 1]
logits = op_model.lm_head(op_model.to(hidden_state))
softmax = F.softmax(logits, dim=-1)
my_token_prob = softmax[seek_token_ids[0]]
if len(seek_token_ids) > 1:
newprompt = code + tokenizer.decode(seek_token_ids[0])
return my_token_prob * next_words(newprompt, position + 1, seek_token_ids[1:])
return my_token_prob
prob = 0
for opt in seek_token_ids:
prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
return ["".join(full_output.tokens), str(prob.item() * 100) + '% chance of risky concatenation']
def code_from_prompts(prompt, model, type_hints, pre_content):
tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
model = ecco.from_pretrained(modelPath[model])
code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)
if type_hints:
code = code.replace('id,', 'id: int,')
code = code.replace('id)', 'id: int)')
code = code.replace('newName)', 'newName: str) -> None')
if pre_content == 'None':
code = code.replace('MIDDLE\n', '')
elif 'Concatenation' in pre_content:
code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n\treturn cur.fetchall()
""".strip() + "\n")
elif 'composition' in pre_content:
code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n\treturn cur.fetchall()
""".strip() + "\n")
results = generation(tokenizer, model, code)
del tokenizer
del model
return results
iface = gr.Interface(
fn=code_from_prompts,
inputs=[
gr.inputs.Textbox(label="Insert comment"),
gr.inputs.Radio(list(modelPath.keys()), label="Code Model"),
gr.inputs.Checkbox(label="Include type hints"),
gr.inputs.Radio([
"None",
"Proper composition: Include function 'WHERE id = %s'",
"Concatenation: Include a function with 'WHERE id = ' + id",
], label="Pre-written function")
],
outputs=[
gr.outputs.Textbox(label="Most probable code"),
gr.outputs.Textbox(label="Probability of concat"),
],
description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()
|