georeactor commited on
Commit
5bc4404
·
1 Parent(s): 5355e8b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import ecco
4
+ import requests
5
+ from transformers import AutoTokenizer
6
+ from torch.nn import functional as F
7
+
8
+ header = """
9
+ import psycopg2
10
+
11
+ conn = psycopg2.connect("CONN")
12
+ cur = conn.cursor()
13
+
14
+ MIDDLE
15
+ def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
16
+ """
17
+
18
+ modelPath = {
19
+ # "GPT2-Medium": "gpt2-medium",
20
+ "CodeParrot-small": "codeparrot/codeparrot-small",
21
+ # "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
22
+ # "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
23
+ # "CodeParrot": "codeparrot/codeparrot",
24
+ # "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
25
+ }
26
+
27
+ preloadModels = {}
28
+ for m in list(modelPath.keys()):
29
+ preloadModels[m] = ecco.from_pretrained(modelPath[m])
30
+
31
+ rankings = requests.get("https://code-adv.herokuapp.com/db").json()['results']
32
+
33
+ def generation(tokenizer, model, content):
34
+ decoder = 'Standard'
35
+ num_beams = 2 if decoder == 'Beam' else None
36
+ typical_p = 0.8 if decoder == 'Typical' else None
37
+ do_sample = (decoder in ['Beam', 'Typical', 'Sample'])
38
+
39
+ seek_token_ids = [
40
+ tokenizer.encode('= \'" +')[1:],
41
+ tokenizer.encode('= " +')[1:],
42
+ ]
43
+
44
+ full_output = model.generate(content, generate=6, do_sample=False)
45
+
46
+ def next_words(code, position, seek_token_ids):
47
+ op_model = model.generate(code, generate=1, do_sample=False)
48
+ hidden_states = op_model.hidden_states
49
+ layer_no = len(hidden_states) - 1
50
+ h = hidden_states[-1]
51
+ hidden_state = h[position - 1]
52
+ logits = op_model.lm_head(op_model.to(hidden_state))
53
+ softmax = F.softmax(logits, dim=-1)
54
+ my_token_prob = softmax[seek_token_ids[0]]
55
+
56
+ if len(seek_token_ids) > 1:
57
+ newprompt = code + tokenizer.decode(seek_token_ids[0])
58
+ return my_token_prob * next_words(newprompt, position + 1, seek_token_ids[1:])
59
+ return my_token_prob
60
+
61
+ prob = 0
62
+ for opt in seek_token_ids:
63
+ prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
64
+ return ["".join(full_output.tokens), str(prob.item() * 100) + '% chance of risky concatenation']
65
+
66
+ def clean_comment(txt):
67
+ return txt.replace("\\", "").replace("\n", " ")
68
+
69
+ def code_from_prompts(
70
+ rankMe,
71
+ headerComment,
72
+ fnComment,
73
+ # model,
74
+ type_hints,
75
+ pre_content):
76
+ tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
77
+ # model = ecco.from_pretrained(modelPath[model])
78
+ # model = preloadModels[model]
79
+ model = preloadModels["CodeParrot-mini"]
80
+
81
+ code = clean_comment(headerComment) + "\n"
82
+ code += header.strip().replace('CONN', "dbname='store'").replace('PROMPT', clean_comment(fnComment))
83
+
84
+ if type_hints:
85
+ code = code.replace('id,', 'id: int,')
86
+ code = code.replace('id)', 'id: int)')
87
+ code = code.replace('newName)', 'newName: str) -> None')
88
+
89
+ if pre_content == 'None':
90
+ code = code.replace('MIDDLE\n', '')
91
+ elif 'Concatenation' in pre_content:
92
+ code = code.replace('MIDDLE', """
93
+ def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n\treturn cur.fetchall()
94
+ """.strip() + "\n")
95
+ elif 'composition' in pre_content:
96
+ code = code.replace('MIDDLE', """
97
+ def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n\treturn cur.fetchall()
98
+ """.strip() + "\n")
99
+
100
+ results = generation(tokenizer, model, code)
101
+ if rankMe:
102
+ prob = results[1]
103
+ requests.post("https://code-adv.herokuapp.com/dbpost", json={
104
+ "password": os.environ.get('SERVE_PASS', 'help'),
105
+ "model": "codeparrot/codeparrot-small",
106
+ "headerComment": headerComment,
107
+ "bodyComment": fnComment,
108
+ "prefunction": pre_content,
109
+ "typeHints": type_hints,
110
+ "probability": prob,
111
+ })
112
+ return results
113
+
114
+ iface = gr.Interface(
115
+ fn=code_from_prompts,
116
+ inputs=[
117
+ gr.components.Checkbox(label="Submit score to server", value=True),
118
+ gr.components.Textbox(label="Header comment"),
119
+ gr.components.Textbox(label="Function comment"),
120
+ # gr.components.Radio(list(modelPath.keys()), label="Code Model"),
121
+ gr.components.Checkbox(label="Include type hints"),
122
+ gr.components.Radio([
123
+ "None",
124
+ "Proper composition: Include function 'WHERE id = %s'",
125
+ "Concatenation: Include a function with 'WHERE id = ' + id",
126
+ ], label="Has user already written a function?")
127
+ ],
128
+ outputs=[
129
+ gr.components.Textbox(label="Most probable code"),
130
+ gr.components.Textbox(label="Probability of concat"),
131
+ ],
132
+ description="Prompt the code model to write a SQL query with string concatenation.",
133
+ )
134
+ iface.launch()