import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, T5ForConditionalGeneration | |
# Model configuration: checkpoint id and tokenization/generation length limits.
# These names are read by predict() below — do not rename.
model_id = 'ksabeh/gavi'
max_input_length = 512   # tokens kept from the prompt (longer inputs are truncated)
max_target_length = 10   # maximum length of the generated attribute value

# Load the fine-tuned T5 model and its matching tokenizer from the Hub.
model = T5ForConditionalGeneration.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def predict(title, category):
    """Generate an attribute value for a product *title* and *category*.

    The two fields are joined into the highlight-delimited prompt format the
    model was fine-tuned on (``"<title> <hl> <category> <hl>"``), run through
    beam-search generation, and the top decoded string is returned.

    Args:
        title: Product title text.
        category: Product category text.

    Returns:
        The decoded model prediction as a plain string.
    """
    prompt = f"{title} <hl> {category} <hl>"
    # return_tensors="pt" yields batched PyTorch tensors directly, replacing
    # the manual torch.tensor(...) + unsqueeze per input key.
    model_input = tokenizer(prompt, max_length=max_input_length,
                            truncation=True, padding="max_length",
                            return_tensors="pt")
    # Use the module-level max_target_length constant instead of a duplicated
    # hard-coded 10, so the limit is configured in one place.
    # NOTE(review): do_sample=True with beam search makes output
    # nondeterministic — confirm sampling is intended here.
    predictions = model.generate(**model_input, num_beams=8, do_sample=True,
                                 max_length=max_target_length)
    return tokenizer.batch_decode(predictions, skip_special_tokens=True)[0]
# Wire predict() into a simple two-textbox web UI and start serving it.
iface = gr.Interface(
    fn=predict,
    inputs=["text", "text"],   # title, category
    outputs=['text'],
    title="Attribute Generation",
)
iface.launch()