# resym / app.py
# Last commit by ejschwartz (5f12c87): "Why did I think indentation was removed?"
import gradio as gr
from gradio_client import Client
from gradio_client.exceptions import AppError
import frontmatter
import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import huggingface_hub
import prep_decompiled
# Authenticate with the Hugging Face Hub. HF_TOKEN must be present in the
# environment; a missing variable raises KeyError at import time.
hf_key = os.environ["HF_TOKEN"]
huggingface_hub.login(token=hf_key)

# A single tokenizer is loaded here; both decoder models below are moved to
# the GPU in bfloat16.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder", torch_dtype=torch.bfloat16
).to("cuda")
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-fielddecoder", torch_dtype=torch.bfloat16
).to("cuda")

# Client for the remote helper Space that is queried for field accesses.
gradio_client = Client("https://ejschwartz-resym-field-helper.hf.space/")

# examples.txt holds one example per line with escape sequences (e.g. "\n")
# written out literally; decode them back into real characters. The file is
# opened in a context manager so the handle is closed once it has been read.
with open("examples.txt", "r") as examples_file:
    examples = [
        line.encode().decode("unicode_escape") for line in examples_file
    ]
# Example prompt
# "input": "```\n_BOOL8 __fastcall sub_409B9A(_QWORD *a1, _QWORD *a2)\n{\nreturn *a1 < *a2 || *a1 == *a2 && a1[1] < a2[1];\n}\n```\nWhat are the variable name and type for the following memory accesses:a1, a1[1], a2, a2[1]?\n",
# "output": "a1: a, os_reltime* -> sec, os_time_t\na1[1]: a, os_reltime* -> usec, os_time_t\na2: b, os_reltime* -> sec, os_time_t\na2[1]: b, os_reltime* -> usec, os_time_t",
def field_prompt(code):
    """Build the field-decoder prompt for a piece of decompiled code.

    The remote field helper is called through ``gradio_client`` with the
    decompiled code. The non-empty ``'expr'`` values found in the first
    element of its result are de-duplicated and sorted to form the list of
    memory accesses the prompt asks about.

    Args:
        code: Decompiled source text, embedded verbatim in the prompt.

    Returns:
        A ``(prompt, fields, field_helper_result)`` tuple. When the helper
        raises ``AppError`` the error is printed and ``(None, [], None)`` is
        returned instead.
    """
    try:
        helper_response = gradio_client.predict(
            decompiled_code=code,
            api_name="/predict",
        )
    except AppError as err:
        print(f"AppError: {err}")
        return None, [], None

    print(f"field helper result: {helper_response}")

    # Gather each distinct, non-empty access expression.
    unique_exprs = set()
    for entry in helper_response[0]:
        expr = entry['expr']
        if expr != '':
            unique_exprs.add(expr)
    fields = sorted(unique_exprs)
    print(f"fields: {fields}")

    access_list = ', '.join(fields)
    prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{access_list}?\n"
    if fields:
        # Seed the answer with the first access so the model continues from it.
        prompt += f"{fields[0]}:"
    print(f"field prompt: {prompt!r}")
    return prompt, fields, helper_response
# Token budget shared by both decoders: prompts are truncated so that the
# prompt plus the generated tokens stay within 8192 tokens.
# NOTE(review): 8192 is presumably the model's context window; confirm.
_MAX_NEW_TOKENS = 1024
_MAX_PROMPT_TOKENS = 8192 - _MAX_NEW_TOKENS


def _generate_completion(model, prompt):
    """Return the text *model* generates after *prompt* (prompt not included)."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()[
        :, :_MAX_PROMPT_TOKENS
    ]
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=_MAX_NEW_TOKENS,
        num_beams=4,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        pad_token_id=0,
        eos_token_id=0,
    )[0]
    # Drop the echoed prompt tokens; decode only what was newly generated.
    return tokenizer.decode(
        output_ids[input_ids.size(1):],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


@spaces.GPU
def infer(code):
    """Run the variable decoder and the field decoder on decompiled code.

    Args:
        code: Decompiled source text of a single function.

    Returns:
        A 4-tuple of strings:
        ``(var_output, field_output, varstring, fieldstring)`` where

        * ``var_output`` is the variable decoder's answer, or
          ``"No variables"`` when no variable names could be extracted;
        * ``field_output`` is the field decoder's answer,
          ``"Failed to parse fields"`` when the field helper failed, or
          ``"No fields"`` when it found no accesses;
        * ``varstring`` is the backtick-quoted, comma-separated variable list;
        * ``fieldstring`` is the comma-separated field access list.
    """
    # NOTE: lines are passed on unstripped (an earlier strip-and-rejoin step
    # was disabled in this file).
    splitcode = code.splitlines()

    body_vars = [
        v["name"] for v in prep_decompiled.extract_comments(splitcode) if "name" in v
    ]
    arg_vars = [
        v["name"] for v in prep_decompiled.parse_signature(splitcode) if "name" in v
    ]
    var_names = arg_vars + body_vars
    varstring = ", ".join([f"`{v}`" for v in var_names])

    if var_names:
        first_var = var_names[0]
        # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
        var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{first_var}:"
        print(f"Prompt:\n{repr(var_prompt)}")
        # The prompt already ends with "<first_var>:", so prepend it to the
        # completion to obtain the full answer.
        var_output = first_var + ":" + _generate_completion(vardecoder_model, var_prompt)
    else:
        # Nothing to ask the variable decoder about; report it instead of
        # failing on an empty list.
        var_output = "No variables"

    field_prompt_result, fields, _ = field_prompt(code)
    if fields:
        field_output = fields[0] + ":" + _generate_completion(
            fielddecoder_model, field_prompt_result
        )
    elif field_prompt_result is None:
        field_output = "Failed to parse fields"
    else:
        field_output = "No fields"

    fieldstring = ", ".join(fields)
    return var_output, field_output, varstring, fieldstring
# Labels for the four result boxes, in the order `infer` returns its values.
_OUTPUT_LABELS = (
    "Var Decoder Output",
    "Field Decoder Output",
    "Generated Variable List",
    "Generated Field Access List",
)

# Single-input UI: the decompiled code goes in, four text results come out.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=10, value=examples[0], label="Hex-Rays Decompilation"),
    ],
    outputs=[gr.Text(label=label) for label in _OUTPUT_LABELS],
    # Alternative kept for reference:
    # description=frontmatter.load("README.md").content,
    description="""This is a test space of the models from the [ReSym
artifacts](https://github.com/lt-asset/resym). For more information, please see
[the README](https://huggingface.co/spaces/ejschwartz/resym/blob/main/README.md).""",
    examples=examples,
)
demo.launch()