Spaces:
Runtime error
Runtime error
File size: 3,508 Bytes
3d9b66a eb4b465 3d9b66a eb4b465 3d9b66a 8e9c357 eb4b465 3d9b66a 2636ace 3d9b66a eb4b465 3d9b66a 8e9c357 3d9b66a eb4b465 3d9b66a 2636ace 8e9c357 eb4b465 3d9b66a 8e9c357 d4533bc 8e9c357 d4533bc 8e9c357 d4533bc 3d9b66a 8e9c357 3d9b66a d4533bc 3d9b66a 2636ace 3d9b66a eb4b465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F
header = """
import psycopg2
conn = psycopg2.connect("CONN")
cur = conn.cursor()
MIDDLE
def rename_customer(id, newName):
# PROMPT
cur.execute("UPDATE customer SET name
"""
# UI label -> Hugging Face model id. The same id is used both for the
# AutoTokenizer and for ecco.from_pretrained. Insertion order determines
# the order of choices in the "Code Model" radio widget.
modelPath = {
    "GPT2-Medium": "gpt2-medium",
    "CodeParrot-mini": "codeparrot/codeparrot-small",
    "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
    "GPT-J": "EleutherAI/gpt-j-6B",
    "CodeParrot": "codeparrot/codeparrot",
    "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}
def generation(tokenizer, model, content):
    """Generate a completion of ``content`` and score the concat continuation.

    Parameters
    ----------
    tokenizer : a HF tokenizer matching ``model``.
    model : an ecco language model (has ``.generate``, ``.lm_head``, ``.to``).
    content : str, the prompt code to complete.

    Returns
    -------
    list ``[generated_text, prob]`` where ``prob`` is the chained probability
    that the model continues the prompt with the tokens of ``'= " +'`` (the
    string-concatenation pattern this demo probes for).
    """
    # Token ids of the concatenation continuation we are probing for.
    # NOTE(review): the [1:] drops the first encoded token — presumably a
    # leading-space/boundary token; confirm for each tokenizer in modelPath.
    seek_token_ids = tokenizer.encode('= " +')[1:]

    # BUG FIX: the original referenced undefined globals `lm` and `code`
    # here, raising NameError at runtime; use the `model` and `content`
    # parameters instead. (Unused decoder-config locals were also removed —
    # they were never passed to generate().)
    full_output = model.generate(content, generate=10, do_sample=False)

    def next_words(code, position, seek_token_ids):
        # P(seek_token_ids[0] | code) from the last hidden state, chained
        # recursively over the remaining seek tokens.
        op_model = model.generate(code, generate=1, do_sample=False)
        h = op_model.hidden_states[-1]       # last layer's hidden states
        hidden_state = h[position - 1]       # state at the final prompt token
        logits = op_model.lm_head(op_model.to(hidden_state))
        softmax = F.softmax(logits, dim=-1)
        my_token_prob = softmax[seek_token_ids[0]]
        if len(seek_token_ids) > 1:
            newprompt = code + tokenizer.decode(seek_token_ids[0])
            return my_token_prob * next_words(newprompt, position + 1, seek_token_ids[1:])
        return my_token_prob

    prob = next_words(content, len(tokenizer(content)['input_ids']), seek_token_ids)
    return ["".join(full_output.tokens), prob]
def code_from_prompts(prompt, model, type_hints, pre_content):
    """Assemble the prompt from the UI options and run the model over it.

    Parameters
    ----------
    prompt : str, the user's comment to insert at the PROMPT placeholder.
    model : str, a key of ``modelPath`` selecting which model to load.
    type_hints : bool, whether to add type annotations to the template.
    pre_content : str, the selected "Pre-written function" radio label.

    Returns the two-element list produced by ``generation``.
    """
    tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
    lm = ecco.from_pretrained(modelPath[model])

    # Fill in the connection string and the user's comment.
    code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)

    if type_hints:
        # Annotate the template's function signature.
        for needle, replacement in (
            ('id,', 'id: int,'),
            ('id)', 'id: int)'),
            ('newName)', 'newName: str) -> None'),
        ):
            code = code.replace(needle, replacement)

    # Swap the MIDDLE placeholder for the chosen pre-written helper (if any).
    if pre_content == 'None':
        code = code.replace('MIDDLE\n', '')
    elif 'Concatenation' in pre_content:
        snippet = """
def get_customer(id):
    cur.execute('SELECT * FROM customers WHERE id = ' + str(id))
    return cur.fetchall()
""".strip() + "\n"
        code = code.replace('MIDDLE', snippet)
    elif 'composition' in pre_content:
        snippet = """
def get_customer(id):
    cur.execute('SELECT * FROM customers WHERE id = %s', str(id))
    return cur.fetchall()
""".strip() + "\n"
        code = code.replace('MIDDLE', snippet)

    results = generation(tokenizer, lm, code)
    # Drop the (potentially large) model and tokenizer before returning.
    del tokenizer
    del lm
    return results
# Wire the UI: one comment textbox, a model picker, a type-hint toggle, and
# a choice of pre-written helper; outputs are the generated code and the
# probability that the model continues with string concatenation.
# NOTE(review): gr.inputs / gr.outputs is the legacy pre-3.0 Gradio API —
# confirm the pinned gradio version before modernizing.
iface = gr.Interface(
    fn=code_from_prompts,
    inputs=[
        gr.inputs.Textbox(label="Insert comment"),
        gr.inputs.Radio(list(modelPath.keys()), label="Code Model"),
        gr.inputs.Checkbox(label="Include type hints"),
        gr.inputs.Radio([
            "None",
            "Proper composition: Include function 'WHERE id = %s'",
            "Concatenation: Include a function with 'WHERE id = ' + id",
        ], label="Pre-written function")
    ],
    outputs=[
        gr.outputs.Textbox(label="Most probable code"),
        gr.outputs.Textbox(label="Probability of concat"),
    ],
    description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()
|