ejschwartz commited on
Commit
0b155f0
·
1 Parent(s): 0adad70
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -76,11 +76,10 @@ def infer(code):
76
  # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
77
  prompt = f"What are the original name and data types of variables {varstring}?\n```{code}\n```{var_name}"
78
 
79
- prompt = code + var_name + ":"
80
- print(prompt)
81
 
82
  input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, : 8192 - 1024]
83
- output = vardecoder_model.generate(
84
  input_ids=input_ids,
85
  max_new_tokens=1024,
86
  num_beams=4,
@@ -90,14 +89,30 @@ def infer(code):
90
  pad_token_id=0,
91
  eos_token_id=0,
92
  )[0]
93
- output = tokenizer.decode(
94
- output[input_ids.size(1) :],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  skip_special_tokens=True,
96
  clean_up_tokenization_spaces=True,
97
  )
98
 
99
- output = var_name + ":" + output
100
- return output, varstring
 
101
 
102
 
103
  demo = gr.Interface(
@@ -106,6 +121,7 @@ demo = gr.Interface(
106
  gr.Textbox(lines=10, value=example),
107
  ],
108
  outputs=[gr.Text(label="Var Decoder Output"),
 
109
  gr.Text(label="Generated Variable List")],
110
  description=description
111
  )
 
76
  # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
77
  prompt = f"What are the original name and data types of variables {varstring}?\n```{code}\n```{var_name}"
78
 
79
+ print(f"Prompt:\n{prompt}")
 
80
 
81
  input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()[:, : 8192 - 1024]
82
+ var_output = vardecoder_model.generate(
83
  input_ids=input_ids,
84
  max_new_tokens=1024,
85
  num_beams=4,
 
89
  pad_token_id=0,
90
  eos_token_id=0,
91
  )[0]
92
+ var_output = tokenizer.decode(
93
+ var_output[input_ids.size(1) :],
94
+ skip_special_tokens=True,
95
+ clean_up_tokenization_spaces=True,
96
+ )
97
+ field_output = fielddecoder_model.generate(
98
+ input_ids=input_ids,
99
+ max_new_tokens=1024,
100
+ num_beams=4,
101
+ num_return_sequences=1,
102
+ do_sample=False,
103
+ early_stopping=False,
104
+ pad_token_id=0,
105
+ eos_token_id=0,
106
+ )[0]
107
+ field_output = tokenizer.decode(
108
+ field_output[input_ids.size(1) :],
109
  skip_special_tokens=True,
110
  clean_up_tokenization_spaces=True,
111
  )
112
 
113
+ var_output = var_name + ":" + var_output
114
+ field_output = var_name + ":" + field_output
115
+ return var_output, field_output, varstring
116
 
117
 
118
  demo = gr.Interface(
 
121
  gr.Textbox(lines=10, value=example),
122
  ],
123
  outputs=[gr.Text(label="Var Decoder Output"),
124
+ gr.Text(label="Field Decoder Output"),
125
  gr.Text(label="Generated Variable List")],
126
  description=description
127
  )