Spaces:
Runtime error
Runtime error
Commit
·
190f1a6
1
Parent(s):
5bc4404
Update app.py
Browse files
app.py
CHANGED
@@ -28,6 +28,7 @@ preloadModels = {}
|
|
28 |
for m in list(modelPath.keys()):
|
29 |
preloadModels[m] = ecco.from_pretrained(modelPath[m])
|
30 |
|
|
|
31 |
rankings = requests.get("https://code-adv.herokuapp.com/db").json()['results']
|
32 |
|
33 |
def generation(tokenizer, model, content):
|
@@ -61,7 +62,11 @@ def generation(tokenizer, model, content):
|
|
61 |
prob = 0
|
62 |
for opt in seek_token_ids:
|
63 |
prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
|
64 |
-
return [
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def clean_comment(txt):
|
67 |
return txt.replace("\\", "").replace("\n", " ")
|
@@ -73,10 +78,11 @@ def code_from_prompts(
|
|
73 |
# model,
|
74 |
type_hints,
|
75 |
pre_content):
|
76 |
-
tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
|
77 |
# model = ecco.from_pretrained(modelPath[model])
|
78 |
# model = preloadModels[model]
|
79 |
-
|
|
|
80 |
|
81 |
code = clean_comment(headerComment) + "\n"
|
82 |
code += header.strip().replace('CONN', "dbname='store'").replace('PROMPT', clean_comment(fnComment))
|
@@ -99,7 +105,7 @@ def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', st
|
|
99 |
|
100 |
results = generation(tokenizer, model, code)
|
101 |
if rankMe:
|
102 |
-
prob = results[1]
|
103 |
requests.post("https://code-adv.herokuapp.com/dbpost", json={
|
104 |
"password": os.environ.get('SERVE_PASS', 'help'),
|
105 |
"model": "codeparrot/codeparrot-small",
|
@@ -116,7 +122,7 @@ iface = gr.Interface(
|
|
116 |
inputs=[
|
117 |
gr.components.Checkbox(label="Submit score to server", value=True),
|
118 |
gr.components.Textbox(label="Header comment"),
|
119 |
-
gr.components.Textbox(label="Function comment"),
|
120 |
# gr.components.Radio(list(modelPath.keys()), label="Code Model"),
|
121 |
gr.components.Checkbox(label="Include type hints"),
|
122 |
gr.components.Radio([
|
@@ -128,6 +134,7 @@ iface = gr.Interface(
|
|
128 |
outputs=[
|
129 |
gr.components.Textbox(label="Most probable code"),
|
130 |
gr.components.Textbox(label="Probability of concat"),
|
|
|
131 |
],
|
132 |
description="Prompt the code model to write a SQL query with string concatenation.",
|
133 |
)
|
|
|
28 |
for m in list(modelPath.keys()):
|
29 |
preloadModels[m] = ecco.from_pretrained(modelPath[m])
|
30 |
|
31 |
+
topComments = []
|
32 |
rankings = requests.get("https://code-adv.herokuapp.com/db").json()['results']
|
33 |
|
34 |
def generation(tokenizer, model, content):
|
|
|
62 |
prob = 0
|
63 |
for opt in seek_token_ids:
|
64 |
prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
|
65 |
+
return [
|
66 |
+
"".join(full_output.tokens),
|
67 |
+
str(prob.item() * 100),
|
68 |
+
rankings
|
69 |
+
]
|
70 |
|
71 |
def clean_comment(txt):
|
72 |
return txt.replace("\\", "").replace("\n", " ")
|
|
|
78 |
# model,
|
79 |
type_hints,
|
80 |
pre_content):
|
81 |
+
# tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
|
82 |
# model = ecco.from_pretrained(modelPath[model])
|
83 |
# model = preloadModels[model]
|
84 |
+
tokenizer = AutoTokenizer.from_pretrained(modelPath["CodeParrot-small"])
|
85 |
+
model = preloadModels["CodeParrot-small"]
|
86 |
|
87 |
code = clean_comment(headerComment) + "\n"
|
88 |
code += header.strip().replace('CONN', "dbname='store'").replace('PROMPT', clean_comment(fnComment))
|
|
|
105 |
|
106 |
results = generation(tokenizer, model, code)
|
107 |
if rankMe:
|
108 |
+
prob = float(results[1])
|
109 |
requests.post("https://code-adv.herokuapp.com/dbpost", json={
|
110 |
"password": os.environ.get('SERVE_PASS', 'help'),
|
111 |
"model": "codeparrot/codeparrot-small",
|
|
|
122 |
inputs=[
|
123 |
gr.components.Checkbox(label="Submit score to server", value=True),
|
124 |
gr.components.Textbox(label="Header comment"),
|
125 |
+
gr.components.Textbox(label="Function comment", label="Top injection comments: " + ",".join(topComments)),
|
126 |
# gr.components.Radio(list(modelPath.keys()), label="Code Model"),
|
127 |
gr.components.Checkbox(label="Include type hints"),
|
128 |
gr.components.Radio([
|
|
|
134 |
outputs=[
|
135 |
gr.components.Textbox(label="Most probable code"),
|
136 |
gr.components.Textbox(label="Probability of concat"),
|
137 |
+
gr.components.Json(value=rankings)
|
138 |
],
|
139 |
description="Prompt the code model to write a SQL query with string concatenation.",
|
140 |
)
|