resym / app.py
ejschwartz's picture
Try examples again
0d90edb
raw
history blame
4.12 kB
import gradio as gr
import json
import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import huggingface_hub
import prep_decompiled
# Markdown text for the page; it is passed as `description=` to gr.Interface
# further down in this file.
description = """# ReSym Test Space
This is a test space of the models from the [ReSym
artifacts](https://github.com/lt-asset/resym). Sadly, at the time I am writing
this, not all of ReSym is publicly available; specifically, the Prolog component
is [not available](https://github.com/lt-asset/resym/issues/2).
This space simply performs inference on the two pretrained models available as
part of the ReSym artifacts. It takes a variable name and some decompiled code
as input, and outputs the variable type and other information.
## Disclaimer
I'm not a ReSym developer and I may have messed something up. In particular,
you must prompt the variable names in the decompiled code as part of the prompt,
and I reused some of their own code to do this.
## Todo
* Add field decoding (probably needs Docker)
"""
# Log in to the Hugging Face Hub before any download is attempted. The token
# must be present in the HF_TOKEN environment variable; a missing variable
# raises KeyError at import time.
# NOTE(review): presumably required because some of these repos are gated or
# private -- confirm.
hf_key = os.environ["HF_TOKEN"]
huggingface_hub.login(token=hf_key)

# One tokenizer is shared by both decoders.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")

# Both checkpoints are loaded with the same dtype / device placement options.
_MODEL_LOAD_KWARGS = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder", **_MODEL_LOAD_KWARGS
)
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-fielddecoder", **_MODEL_LOAD_KWARGS
)
# Sample Hex-Rays decompilation used as the default value of the input textbox.
# Kept as a raw string so the text is taken literally.
example = r"""__int64 __fastcall sub_410D81(__int64 a1, __int64 a2, __int64 a3)
{
int v4; // [rsp+20h] [rbp-20h] BYREF
__int64 v5; // [rsp+28h] [rbp-18h]
if ( !a1 || !a2 || !a3 )
return 0LL;
v4 = 5;
v5 = a3;
return sub_411142(a1, a2, &v4);
}"""
# Clickable examples for the UI, one per line of examples.txt.
# Each line is run through the "unicode_escape" codec so that backslash escape
# sequences written in the file (e.g. a literal "\n") become the real
# characters. The file is opened in a `with` block so the handle is closed as
# soon as the list is built instead of being left to the garbage collector.
with open("examples.txt", "r") as _examples_file:
    examples = [
        line.encode().decode("unicode_escape") for line in _examples_file.readlines()
    ]
# Token budget used when building the prompt and generating.
# NOTE(review): 8192 is presumably the model's context window -- confirm.
_MAX_CONTEXT_TOKENS = 8192
_MAX_NEW_TOKENS = 1024


def _generate_completion(model, input_ids):
    """Run beam search on *model* and return only the newly generated text.

    Args:
        model: A causal LM exposing ``generate`` (e.g. ``vardecoder_model``).
        input_ids: Prompt token ids as returned by ``tokenizer.encode(...,
            return_tensors="pt")``, already on the model's device.

    Returns:
        The decoded completion with the prompt tokens and special tokens
        removed.
    """
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=_MAX_NEW_TOKENS,
        num_beams=4,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        pad_token_id=0,
        eos_token_id=0,
    )[0]
    # The returned sequence starts with the prompt; slice it off so only the
    # completion is decoded.
    return tokenizer.decode(
        output_ids[input_ids.size(1):],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


@spaces.GPU
def infer(code):
    """Predict original variable names/types for a decompiled function.

    The code is normalised (every line stripped of surrounding whitespace),
    the variable names are collected from the signature and from the body
    comments via ``prep_decompiled``, and the var decoder model is prompted
    with the resulting variable list and code.

    Args:
        code: Hex-Rays decompilation of a single function, as text.

    Returns:
        A ``(var_output, varstring)`` tuple: the var decoder completion
        prefixed with ``"<first variable>:"``, and the comma-separated,
        backtick-quoted list of variables that was put in the prompt.

    Raises:
        gr.Error: If no variable could be extracted from *code*.
    """
    code_lines = [line.strip() for line in code.splitlines()]
    stripped_code = "\n".join(code_lines)

    # Entries without a "name" key are skipped.
    body_vars = [
        v["name"] for v in prep_decompiled.extract_comments(code_lines) if "name" in v
    ]
    arg_vars = [
        v["name"] for v in prep_decompiled.parse_signature(code_lines) if "name" in v
    ]
    var_names = arg_vars + body_vars

    # Without at least one variable there is nothing to seed the prompt with;
    # fail with a readable message instead of an IndexError on var_names[0].
    if not var_names:
        raise gr.Error(
            "No variables were found in the decompiled code; please paste a "
            "complete Hex-Rays function (signature and body)."
        )

    varstring = ", ".join([f"`{name}`" for name in var_names])
    var_name = var_names[0]

    # ejs: Yeah, this var_name thing is really bizarre. But look at
    # https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
    # The prompt ends with the first variable name, so the model's completion
    # continues from it; the name is re-attached to the output below.
    var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{stripped_code}\n```{var_name}"
    print(f"Prompt:\n{var_prompt}")

    # Truncate the prompt so that prompt + generated tokens fit in the budget.
    input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
        :, : _MAX_CONTEXT_TOKENS - _MAX_NEW_TOKENS
    ]

    # NOTE: the field decoder used to be run here as well, but its output was
    # never returned, so it only cost GPU time. To surface it, call
    # _generate_completion(fielddecoder_model, ...) and add it to the return.
    var_output = var_name + ":" + _generate_completion(vardecoder_model, input_ids)
    return var_output, varstring
# Input: one multi-line textbox, pre-filled with the sample decompilation.
_code_input = gr.Textbox(lines=10, value=example, label="Hex-Rays Decompilation")

# Outputs: one text field per value returned by `infer`. A
# "Field Decoder Output" field is intentionally not shown.
_result_outputs = [
    gr.Text(label="Var Decoder Output"),
    gr.Text(label="Generated Variable List"),
]

# Assemble the app and start serving it.
demo = gr.Interface(
    fn=infer,
    inputs=[_code_input],
    outputs=_result_outputs,
    description=description,
    examples=examples,
)
demo.launch()