"""Gradio demo that runs the ReSym variable and field decoder models.

The app takes Hex-Rays decompiled code, asks the variable decoder for
original variable names/types, asks a remote helper Space which field
accesses occur in the code, and then asks the field decoder for the
names/types of those accesses.
"""

import os

import gradio as gr
from gradio_client import Client
from gradio_client.exceptions import AppError
import frontmatter
# NOTE(review): `spaces` is imported before `torch`, as in the original
# file; keep this relative order unless it is confirmed safe to change.
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import huggingface_hub

import prep_decompiled

# Prompts are truncated to this many tokens before generation.
# NOTE(review): presumably an 8192-token context window minus the
# 1024-token generation budget -- confirm against the model config.
MAX_NEW_TOKENS = 1024
MAX_PROMPT_TOKENS = 8192 - MAX_NEW_TOKENS

hf_key = os.environ["HF_TOKEN"]
huggingface_hub.login(token=hf_key)

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder", torch_dtype=torch.bfloat16
).to("cuda")
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-fielddecoder", torch_dtype=torch.bfloat16
).to("cuda")

gradio_client = Client("https://ejschwartz-resym-field-helper.hf.space/")

# Each line of examples.txt is one example; escape sequences such as a
# literal backslash-n are expanded into the characters they denote.
# The file is read inside `with` so the handle is closed deterministically.
with open("examples.txt", "r", encoding="utf-8") as examples_file:
    examples = [
        ex.encode().decode("unicode_escape") for ex in examples_file
    ]

# Example prompt
# "input": "```\n_BOOL8 __fastcall sub_409B9A(_QWORD *a1, _QWORD *a2)\n{\nreturn *a1 < *a2 || *a1 == *a2 && a1[1] < a2[1];\n}\n```\nWhat are the variable name and type for the following memory accesses:a1, a1[1], a2, a2[1]?\n",
# "output": "a1: a, os_reltime* -> sec, os_time_t\na1[1]: a, os_reltime* -> usec, os_time_t\na2: b, os_reltime* -> sec, os_time_t\na2[1]: b, os_reltime* -> usec, os_time_t",


def field_prompt(code):
    """Build the field-decoder prompt for *code*.

    Calls the remote field-helper Space to find the memory-access
    expressions in the decompiled code.

    Args:
        code: Decompiled source text.

    Returns:
        A ``(prompt, fields, field_helper_result)`` tuple. ``fields`` is
        the sorted list of unique, non-empty ``expr`` values taken from
        the first element of the helper's result. If the helper raises
        ``AppError`` the tuple is ``(None, [], None)``.
    """
    try:
        field_helper_result = gradio_client.predict(
            decompiled_code=code,
            api_name="/predict",
        )
    except AppError as e:
        print(f"AppError: {e}")
        return None, [], None

    print(f"field helper result: {field_helper_result}")

    # De-duplicate and sort so the prompt is deterministic.
    fields = sorted(
        {e["expr"] for e in field_helper_result[0] if e["expr"] != ""}
    )

    print(f"fields: {fields}")

    field_list = ", ".join(fields)
    prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{field_list}?\n"

    # The prompt ends with the first field's name so the model continues
    # from there.
    if fields:
        prompt += f"{fields[0]}:"

    print(f"field prompt: {repr(prompt)}")

    return prompt, fields, field_helper_result


def _generate_completion(model, prompt):
    """Return the text *model* generates as a continuation of *prompt*.

    The prompt is tokenized, truncated to ``MAX_PROMPT_TOKENS`` tokens and
    decoded with 4-beam search; only the newly generated tokens are
    decoded and returned.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()[
        :, :MAX_PROMPT_TOKENS
    ]
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=4,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        pad_token_id=0,
        eos_token_id=0,
    )[0]
    # Drop the echoed prompt tokens and keep only the continuation.
    return tokenizer.decode(
        output_ids[input_ids.size(1):],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


@spaces.GPU
def infer(code):
    """Run both decoders on *code* and return the four UI outputs.

    Args:
        code: Decompiled source text entered by the user.

    Returns:
        A ``(var_output, field_output, varstring, fieldstring)`` tuple:
        the variable decoder's answer, the field decoder's answer (or a
        short status message when there are no fields or the helper
        failed), the backtick-quoted variable list, and the
        comma-separated field list. When no variables are found,
        ``var_output`` is ``"No variables"`` and the variable decoder is
        not run.
    """
    splitcode = code.splitlines()
    bodyvars = [
        v["name"] for v in prep_decompiled.extract_comments(splitcode) if "name" in v
    ]
    argvars = [
        v["name"] for v in prep_decompiled.parse_signature(splitcode) if "name" in v
    ]
    variables = argvars + bodyvars

    varstring = ", ".join([f"`{v}`" for v in variables])

    if variables:
        first_var = variables[0]

        # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
        var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{first_var}:"

        print(f"Prompt:\n{repr(var_prompt)}")

        # The prompt already ends with "<first_var>:", so prepend it to
        # the continuation to form a complete answer.
        var_output = first_var + ":" + _generate_completion(
            vardecoder_model, var_prompt
        )
    else:
        # Nothing to ask the model about; previously this path raised
        # IndexError on `vars[0]`.
        var_output = "No variables"

    field_prompt_result, fields, _ = field_prompt(code)
    if not fields:
        # A None prompt means the helper Space failed rather than
        # reporting zero fields.
        field_output = (
            "Failed to parse fields" if field_prompt_result is None else "No fields"
        )
    else:
        field_output = fields[0] + ":" + _generate_completion(
            fielddecoder_model, field_prompt_result
        )

    fieldstring = ", ".join(fields)

    return var_output, field_output, varstring, fieldstring


demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(
            lines=10,
            # Guard against an empty examples.txt.
            value=examples[0] if examples else "",
            label="Hex-Rays Decompilation",
        ),
    ],
    outputs=[
        gr.Text(label="Var Decoder Output"),
        gr.Text(label="Field Decoder Output"),
        gr.Text(label="Generated Variable List"),
        gr.Text(label="Generated Field Access List"),
    ],
    # description=frontmatter.load("README.md").content,
    description="""This is a test space of the models from the [ReSym artifacts](https://github.com/lt-asset/resym). For more information, please see [the README](https://huggingface.co/spaces/ejschwartz/resym/blob/main/README.md).""",
    examples=examples or None,
)
demo.launch()