georeactor commited on
Commit
190f1a6
·
1 Parent(s): 5bc4404

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -28,6 +28,7 @@ preloadModels = {}
28
  for m in list(modelPath.keys()):
29
  preloadModels[m] = ecco.from_pretrained(modelPath[m])
30
 
 
31
  rankings = requests.get("https://code-adv.herokuapp.com/db").json()['results']
32
 
33
  def generation(tokenizer, model, content):
@@ -61,7 +62,11 @@ def generation(tokenizer, model, content):
61
  prob = 0
62
  for opt in seek_token_ids:
63
  prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
64
- return ["".join(full_output.tokens), str(prob.item() * 100) + '% chance of risky concatenation']
 
 
 
 
65
 
66
  def clean_comment(txt):
67
  return txt.replace("\\", "").replace("\n", " ")
@@ -73,10 +78,11 @@ def code_from_prompts(
73
  # model,
74
  type_hints,
75
  pre_content):
76
- tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
77
  # model = ecco.from_pretrained(modelPath[model])
78
  # model = preloadModels[model]
79
- model = preloadModels["CodeParrot-mini"]
 
80
 
81
  code = clean_comment(headerComment) + "\n"
82
  code += header.strip().replace('CONN', "dbname='store'").replace('PROMPT', clean_comment(fnComment))
@@ -99,7 +105,7 @@ def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', st
99
 
100
  results = generation(tokenizer, model, code)
101
  if rankMe:
102
- prob = results[1]
103
  requests.post("https://code-adv.herokuapp.com/dbpost", json={
104
  "password": os.environ.get('SERVE_PASS', 'help'),
105
  "model": "codeparrot/codeparrot-small",
@@ -116,7 +122,7 @@ iface = gr.Interface(
116
  inputs=[
117
  gr.components.Checkbox(label="Submit score to server", value=True),
118
  gr.components.Textbox(label="Header comment"),
119
- gr.components.Textbox(label="Function comment"),
120
  # gr.components.Radio(list(modelPath.keys()), label="Code Model"),
121
  gr.components.Checkbox(label="Include type hints"),
122
  gr.components.Radio([
@@ -128,6 +134,7 @@ iface = gr.Interface(
128
  outputs=[
129
  gr.components.Textbox(label="Most probable code"),
130
  gr.components.Textbox(label="Probability of concat"),
 
131
  ],
132
  description="Prompt the code model to write a SQL query with string concatenation.",
133
  )
 
28
  for m in list(modelPath.keys()):
29
  preloadModels[m] = ecco.from_pretrained(modelPath[m])
30
 
31
+ topComments = []
32
  rankings = requests.get("https://code-adv.herokuapp.com/db").json()['results']
33
 
34
  def generation(tokenizer, model, content):
 
62
  prob = 0
63
  for opt in seek_token_ids:
64
  prob += next_words(content, len(tokenizer(content)['input_ids']), opt)
65
+ return [
66
+ "".join(full_output.tokens),
67
+ str(prob.item() * 100),
68
+ rankings
69
+ ]
70
 
71
  def clean_comment(txt):
72
  return txt.replace("\\", "").replace("\n", " ")
 
78
  # model,
79
  type_hints,
80
  pre_content):
81
+ # tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
82
  # model = ecco.from_pretrained(modelPath[model])
83
  # model = preloadModels[model]
84
+ tokenizer = AutoTokenizer.from_pretrained(modelPath["CodeParrot-small"])
85
+ model = preloadModels["CodeParrot-small"]
86
 
87
  code = clean_comment(headerComment) + "\n"
88
  code += header.strip().replace('CONN', "dbname='store'").replace('PROMPT', clean_comment(fnComment))
 
105
 
106
  results = generation(tokenizer, model, code)
107
  if rankMe:
108
+ prob = float(results[1])
109
  requests.post("https://code-adv.herokuapp.com/dbpost", json={
110
  "password": os.environ.get('SERVE_PASS', 'help'),
111
  "model": "codeparrot/codeparrot-small",
 
122
  inputs=[
123
  gr.components.Checkbox(label="Submit score to server", value=True),
124
  gr.components.Textbox(label="Header comment"),
125
+ gr.components.Textbox(label="Function comment", label="Top injection comments: " + ",".join(topComments)),
126
  # gr.components.Radio(list(modelPath.keys()), label="Code Model"),
127
  gr.components.Checkbox(label="Include type hints"),
128
  gr.components.Radio([
 
134
  outputs=[
135
  gr.components.Textbox(label="Most probable code"),
136
  gr.components.Textbox(label="Probability of concat"),
137
+ gr.components.Json(value=rankings)
138
  ],
139
  description="Prompt the code model to write a SQL query with string concatenation.",
140
  )