# Spaces status: Runtime error
import functools

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Prompt template given to the code model. ``CONN`` and ``PROMPT`` are
# placeholders that code_from_prompts substitutes (connection string and the
# user's comment). The final ``cur.execute(...`` line is left unterminated;
# generation continues the text from there.
header = "\n".join(
    [
        "",
        "import psycopg2",
        'conn = psycopg2.connect("CONN")',
        "cur = conn.cursor()",
        "def set_customer_name(id: int, new_name: str):",
        "    # PROMPT",
        '    cur.execute("UPDATE customer SET name',
        "",
    ]
)
# UI display name -> Hugging Face Hub model id passed to ``from_pretrained``.
# Insertion order matters: it is the order of the "Code Model" radio choices.
_MODEL_CHOICES = (
    ("GPT2-Medium", "gpt2-medium"),
    ("CodeParrot-mini", "codeparrot/codeparrot-small"),
    ("CodeGen-350-Mono", "Salesforce/codegen-350M-mono"),
    # NOTE(review): a 6B-parameter checkpoint needs roughly 24 GB in float32;
    # confirm the host has that much memory.
    ("GPT-J", "EleutherAI/gpt-j-6B"),
    ("CodeParrot", "codeparrot/codeparrot"),
    ("CodeGen-2B-Mono", "Salesforce/codegen-2B-mono"),
)
modelPath = dict(_MODEL_CHOICES)
def generation(tokenizer, model, content, decoder="Typical", max_length=120):
    """Continue *content* with *model* and return the decoded text.

    Args:
        tokenizer: Object with ``encode``/``decode`` used to turn the prompt
            into PyTorch token ids and the output ids back into text (a Hugging
            Face tokenizer in this app).
        model: Causal language model whose ``generate`` method produces ids.
        content: Prompt text to continue.
        decoder: Decoding strategy. ``"Beam"`` = 2-beam search with sampling
            enabled and early stopping; ``"Typical"`` = sampling with
            ``typical_p=0.8``; ``"Sample"`` = plain sampling. Any other value
            turns sampling off and leaves the remaining search settings at the
            model's defaults. Defaults to ``"Typical"``.
        max_length: Maximum total length in tokens, prompt included.

    Returns:
        The first generated sequence (prompt plus continuation), decoded with
        special tokens skipped.
    """
    input_ids = tokenizer.encode(content, return_tensors="pt")

    # Pass only the knobs the chosen strategy needs; anything omitted keeps the
    # model's generation defaults. (Passing explicit ``None`` values is not
    # treated as "use the default" by newer ``GenerationConfig.update``.)
    strategy_kwargs = {}
    if decoder == "Beam":
        strategy_kwargs.update(num_beams=2, early_stopping=True)
    elif decoder == "Typical":
        strategy_kwargs["typical_p"] = 0.8

    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=decoder in ("Beam", "Typical", "Sample"),
        repetition_penalty=4.0,
        **strategy_kwargs,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Return a ``(tokenizer, model)`` pair for a ``modelPath`` display name.

    Cached with ``maxsize=1``: repeated requests for the same model reuse the
    loaded objects instead of re-reading the checkpoint on every call, and
    selecting a different model evicts the previous pair from the cache, so the
    cache holds at most one checkpoint at a time.
    """
    path = modelPath[model_name]
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    return tokenizer, model


def code_from_prompts(prompt, model, type_hints):
    """Gradio callback: fill the prompt template and let the chosen model complete it.

    Args:
        prompt: Free-text comment substituted for ``PROMPT`` in ``header``.
            Line breaks are folded to spaces so the text stays on the single
            comment line of the template.
        model: Display name of the model; must be a key of ``modelPath``.
        type_hints: When false, the ``int``/``str`` annotations are removed
            from the ``set_customer_name`` signature in the prompt.

    Returns:
        ``[generated_code, probability]`` for the two output textboxes. The
        probability is currently a fixed ``0.5`` placeholder.

    Raises:
        ValueError: If *model* is not a key of ``modelPath`` (e.g. no model
            was selected in the UI).
    """
    if model not in modelPath:
        raise ValueError(f"Unknown model {model!r}; choose one of {list(modelPath)}")
    tokenizer, lm = _load_model(model)

    code = header.strip()
    if not type_hints:
        # Must match the signature text in ``header`` exactly; otherwise a no-op.
        # Done before inserting user text so that text is never rewritten.
        code = code.replace("id: int, new_name: str", "id, new_name")
    # A raw newline would push the rest of the user's text out of the
    # ``# PROMPT`` comment and into code lines.
    one_line_prompt = " ".join((prompt or "").splitlines())
    # Fill CONN before PROMPT so user text containing "CONN" is left untouched.
    code = code.replace("CONN", "dbname='store'").replace("PROMPT", one_line_prompt)

    # TODO: 0.5 is a placeholder; nothing computes a real probability yet.
    return [generation(tokenizer, lm, code), 0.5]
# Web UI: a comment, a model choice and a type-hint toggle go in; the completed
# code and the (placeholder) probability come out.
# Uses the top-level Gradio components; the ``gr.inputs``/``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=code_from_prompts,
    inputs=[
        gr.Textbox(label="Insert comment"),
        # Pre-select the first model so the callback never receives ``None``.
        gr.Radio(
            list(modelPath.keys()),
            value=next(iter(modelPath)),
            label="Code Model",
        ),
        gr.Checkbox(label="Include type hints"),
    ],
    outputs=[
        gr.Textbox(label="Generated code"),
        gr.Textbox(label="Probability"),
    ],
    description="What does a code-generation model assume about the name in the header comment?",
)
iface.launch()