georeactor committed on
Commit
eb4b465
·
1 Parent(s): d94d42b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
2
  import torch
 
3
  import requests
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
  header = """
7
  import psycopg2
@@ -10,7 +12,7 @@ conn = psycopg2.connect("CONN")
10
  cur = conn.cursor()
11
 
12
  MIDDLE
13
- def rename_customer(id, new_name):
14
  # PROMPT
15
  cur.execute("UPDATE customer SET name
16
  """
@@ -25,36 +27,43 @@ modelPath = {
25
  }
26
 
27
def generation(tokenizer, model, content):
    """Complete *content* with the language model.

    Returns a two-element list: the decoded completion text and a
    probability value (fixed placeholder in this version).
    """
    # Encode the prompt into a tensor of token ids for the model.
    input_ids = tokenizer.encode(content, return_tensors='pt')

    # Decoding strategy is pinned to 'Standard'; the alternative settings
    # below stay inert unless this name is changed.
    decoder = 'Standard'
    num_beams = 2 if decoder == 'Beam' else None
    typical_p = 0.8 if decoder == 'Typical' else None
    do_sample = decoder in ['Beam', 'Typical', 'Sample']

    generated = model.generate(
        input_ids,
        max_length=120,
        num_beams=num_beams,
        early_stopping=True,
        do_sample=do_sample,
        typical_p=typical_p,
        repetition_penalty=4.0,
    )
    # Decode the first (only) sequence, dropping BOS/EOS/pad markers.
    completion_text = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Placeholder probability — not derived from the model output here.
    placeholder_prob = 0.5
    return [completion_text, placeholder_prob]
 
 
 
 
 
 
 
 
47
 
48
  def code_from_prompts(prompt, model, type_hints, pre_content):
49
  tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
50
- model = AutoModelForCausalLM.from_pretrained(modelPath[model])
51
 
52
  code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)
53
 
54
  if type_hints:
55
  code = code.replace('id,', 'id: int,')
56
  code = code.replace('id)', 'id: int)')
57
- code = code.replace('new_name)', 'new_name: str) -> None')
58
 
59
  if pre_content == 'None':
60
  code = code.replace('MIDDLE\n', '')
@@ -94,4 +103,4 @@ iface = gr.Interface(
94
  ],
95
  description="Prompt the code model to write a SQL query with string concatenation.",
96
  )
97
- iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import ecco
4
  import requests
5
+ from transformers import AutoTokenizer
6
+ from torch.nn import functional as F
7
 
8
  header = """
9
  import psycopg2
 
12
  cur = conn.cursor()
13
 
14
  MIDDLE
15
+ def rename_customer(id, newName):
16
  # PROMPT
17
  cur.execute("UPDATE customer SET name
18
  """
 
27
  }
28
 
29
def generation(tokenizer, model, content):
    """Complete *content* with an ecco language model and estimate the
    probability that the model continues the prompt with the insecure
    SQL string-concatenation tokens `= " +`.

    Args:
        tokenizer: Hugging Face tokenizer matching *model*.
        model: ecco language model (as returned by ecco.from_pretrained).
        content: prompt code to complete.

    Returns:
        [generated_text, probability] — the greedy 10-token completion and
        the product of the per-token probabilities of the `= " +` sequence.
    """
    # Token ids of the continuation we probe for; [1:] drops the leading
    # special/BOS token the tokenizer prepends to the encoded string.
    seek_token_ids = tokenizer.encode('= " +')[1:]

    # Greedy 10-token completion that is shown to the user.
    # BUG FIX: the original referenced undefined names `lm` and `code`;
    # the model and prompt arrive as the parameters `model` and `content`.
    full_output = model.generate(content, generate=10, do_sample=False)

    def next_words(prompt, position, token_ids):
        # Probability of token_ids[0] appearing at `position`, multiplied
        # recursively by the probabilities of the remaining tokens.
        op_output = model.generate(prompt, generate=1, do_sample=False)
        # Last layer's hidden state at the final prompt position
        # (position is 1-based relative to the prompt's token count).
        hidden_state = op_output.hidden_states[-1][position - 1]
        # Project through the LM head to vocabulary logits, then softmax.
        # NOTE(review): relies on ecco exposing lm_head/to on the output
        # object — confirm against the installed ecco version.
        logits = op_output.lm_head(op_output.to(hidden_state))
        probs = F.softmax(logits, dim=-1)
        token_prob = probs[token_ids[0]]

        if len(token_ids) > 1:
            # Append the matched token and score the rest of the sequence.
            extended_prompt = prompt + tokenizer.decode(token_ids[0])
            return token_prob * next_words(extended_prompt, position + 1, token_ids[1:])
        return token_prob

    prob = next_words(content, len(tokenizer(content)['input_ids']), seek_token_ids)
    return ["".join(full_output.tokens), prob]
 
57
  def code_from_prompts(prompt, model, type_hints, pre_content):
58
  tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
59
+ model = ecco.from_pretrained(modelPath[model])
60
 
61
  code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)
62
 
63
  if type_hints:
64
  code = code.replace('id,', 'id: int,')
65
  code = code.replace('id)', 'id: int)')
66
+ code = code.replace('newName)', 'newName: str) -> None')
67
 
68
  if pre_content == 'None':
69
  code = code.replace('MIDDLE\n', '')
 
103
  ],
104
  description="Prompt the code model to write a SQL query with string concatenation.",
105
  )
106
+ iface.launch()