georeactor committed
Commit 3d9b66a · 1 Parent(s): b23c77a

Upload app.py

Files changed (1)
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ import torch
+ import requests
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ header = """
+ import psycopg2
+
+ conn = psycopg2.connect("CONN")
+ cur = conn.cursor()
+
+ def set_customer_name(id: int, new_name: str):
+     # PROMPT
+     cur.execute("UPDATE customer SET name
+ """
+
+ modelPath = {
+     "GPT2-Medium": "gpt2-medium",
+     "CodeParrot-mini": "codeparrot/codeparrot-small",
+     "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
+     "GPT-J": "EleutherAI/gpt-j-6B",
+     "CodeParrot": "codeparrot/codeparrot",
+     "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
+ }
+
+ def generation(tokenizer, model, content, decoder="Sample"):
+     # decoder selects the decoding strategy ("Beam", "Typical", or "Sample");
+     # it is not exposed in the UI, so plain sampling is the default.
+     input_ids = tokenizer.encode(content, return_tensors='pt')
+     num_beams = 2 if decoder == 'Beam' else None
+     typical_p = 0.8 if decoder == 'Typical' else None
+     do_sample = decoder in ['Beam', 'Typical', 'Sample']
+
+     typ_output = model.generate(
+         input_ids,
+         max_length=120,
+         num_beams=num_beams,
+         early_stopping=True,
+         do_sample=do_sample,
+         typical_p=typical_p,
+         repetition_penalty=4.0,
+     )
+     txt = tokenizer.decode(typ_output[0], skip_special_tokens=True)
+     return txt
+
+ def code_from_prompts(prompt, model, type_hints):
+     tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
+     model = AutoModelForCausalLM.from_pretrained(modelPath[model])
+
+     # Fill the prompt template: connection string and the user's comment.
+     code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)
+
+     # if type_hints:  # TODO: the "Include type hints" option is not implemented yet
+
+     results = [
+         generation(tokenizer, model, code),
+         0.5,  # placeholder shown in the "Probability" output
+     ]
+     del tokenizer
+     del model
+     return results
+
+ iface = gr.Interface(
+     fn=code_from_prompts,
+     inputs=[
+         gr.inputs.Textbox(label="Insert comment"),
+         gr.inputs.Radio(list(modelPath.keys()), label="Code Model"),
+         gr.inputs.Checkbox(label="Include type hints"),
+     ],
+     outputs=[
+         gr.outputs.Textbox(label="Generated code"),
+         gr.outputs.Textbox(label="Probability"),
+     ],
+     description="What does a code-generation model assume about the name in the header comment?",
+ )
+ iface.launch()
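
For reference, a minimal standalone sketch (not part of the commit) of the same prompt-template assembly and generation step, run outside the Gradio interface. It uses the smallest model from modelPath ("CodeParrot-mini", i.e. codeparrot/codeparrot-small) for speed, and the comment filled into the template is an arbitrary example, not a value fixed by the app.

# Standalone sketch: reproduce the app's assembled prompt and generate a
# continuation with one small model. The comment line stands in for the
# "Insert comment" textbox value; "dbname='store'" replaces CONN as in the app.
from transformers import AutoTokenizer, AutoModelForCausalLM

prompt_template = """
import psycopg2

conn = psycopg2.connect("dbname='store'")
cur = conn.cursor()

def set_customer_name(id: int, new_name: str):
    # Update the customer's name in the customer table
    cur.execute("UPDATE customer SET name
"""

model_id = "codeparrot/codeparrot-small"   # "CodeParrot-mini" in modelPath
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer.encode(prompt_template.strip(), return_tensors="pt")
output = model.generate(
    input_ids,
    max_length=120,
    do_sample=True,            # the app's default "Sample" decoding path
    repetition_penalty=4.0,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))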