Spaces:
Sleeping
Sleeping
from transformers import T5ForConditionalGeneration, AutoTokenizer | |
import gradio as gr | |
# Load the model and tokenizer once at import time; the prediction helpers
# below read the module-level ``model``.
model_name = "ejschwartz/hext5"  # Replace with your desired model
# Both objects must be loaded from the same hub checkpoint id.
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# predict summary
def predict_summary(tokenizer, code, *, max_input_length=512):
    """Generate a natural-language summary of ``code`` with the global ``model``.

    Args:
        tokenizer: Tokenizer matching the global ``model``; it is called on the
            prompt with ``return_tensors='pt'``.
        code: Text of the function to summarize.
        max_input_length: Maximum number of prompt tokens; longer prompts are
            truncated. 512 is an assumed default -- TODO confirm against the
            length the checkpoint was fine-tuned with.

    Returns:
        The decoded summary, with special tokens removed.
    """
    # The "summarize: " task prefix selects the summarization task.
    encoded = tokenizer(
        'summarize: ' + code,
        return_tensors='pt',
        max_length=max_input_length,
        truncation=True,
    )
    # generate() returns a batch of sequences; only one prompt was encoded.
    output_ids = model.generate(**encoded, max_new_tokens=256)[0]
    return tokenizer.decode(output_ids, skip_special_tokens=True)
# predict identifier (func name)
def predict_identifier(tokenizer, code, *, max_input_length=512):
    '''Predict identifiers for the placeholders in ``code`` with the global ``model``.

    code should be like: "unsigned __int8 *__cdecl <func>(int *<var_0>,...){ return <func_1>(1);}"

    Args:
        tokenizer: Tokenizer matching the global ``model``; it is called on the
            prompt with ``return_tensors='pt'``.
        code: Function text whose identifiers are replaced by placeholders.
        max_input_length: Maximum number of prompt tokens; longer prompts are
            truncated. 512 is an assumed default -- TODO confirm against the
            length the checkpoint was fine-tuned with.

    Returns:
        The decoded model output. Special tokens are NOT stripped (unlike
        ``predict_summary``), so markers such as the pad/eos tokens may appear.
    '''
    # The "identifier_predict: " task prefix selects identifier prediction.
    encoded = tokenizer(
        'identifier_predict: ' + code,
        return_tensors='pt',
        max_length=max_input_length,
        truncation=True,
    )
    output_ids = model.generate(**encoded, max_new_tokens=250)[0]
    # NOTE(review): special tokens are kept deliberately, presumably because
    # placeholder tokens like <var_0> are registered as special tokens and are
    # needed to map predictions back to placeholders -- TODO confirm.
    return tokenizer.decode(output_ids)
# Define the inference function
def generate_text(prompt):
    """Run ``prompt`` through the global model and return the decoded output.

    Generation is capped with ``max_length=100`` and special tokens are
    stripped from the returned text. Only the first generated sequence is used.
    """
    first_sequence = model.generate(
        **tokenizer(prompt, return_tensors="pt"),
        max_length=100,
    )[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Create the Gradio interface
def _predict_identifier_ui(code):
    """Gradio callback for the identifier-prediction demo.

    Gradio calls ``fn`` with one argument per input component (here: the
    textbox value only), so the module-level tokenizer is bound here rather
    than exposed as an input.
    """
    return predict_identifier(tokenizer, code)


iface = gr.Interface(
    fn=_predict_identifier_ui,
    # Multi-line box: the input is the text of a whole function.
    inputs=gr.Textbox(lines=10, label="Function with placeholder identifiers"),
    outputs=gr.Textbox(label="Predicted identifiers"),
    title="Predict identifiers",
    # The description is rendered as Markdown; the backticks keep <func> and
    # <var_0> from being interpreted as HTML tags.
    description=(
        "Paste a function whose identifiers are replaced by placeholders, e.g. "
        "`unsigned __int8 *__cdecl <func>(int *<var_0>,...){ return <func_1>(1);}`, "
        "and the model predicts the identifiers."
    ),
)
# Launch the interface. launch() blocks while the server runs, so do not add
# a second app launch after it: it would only start once this one shuts down.
iface.launch()