# Hugging Face Space: CCASS semi-automatic "titrages" demo
# (lines above were page-scrape residue: "Spaces: Sleeping Sleeping")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Seq2seq model fine-tuned to propose "titrages" (hierarchical titles) for
# French Cour de cassation (CCASS) decisions. Loaded once at startup so every
# request reuses the same weights.
tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")
def generate_titre(matiere_and_titrage_prefix):
    """Predict the next "titre" (title segment) for a CCASS decision.

    Parameters
    ----------
    matiere_and_titrage_prefix : str
        The model prompt: the "matière", any already-accepted titrage
        prefix, and the "sommaire" text, joined into one string.

    Returns
    -------
    str
        The decoded model prediction for the next title segment.
    """
    # NOTE(review): the original tried to pop an interactive
    # gr.inputs.Confirm() dialog here (no such component exists in Gradio,
    # so this raised AttributeError), appended to an uninitialized local
    # `pred_titre` (UnboundLocalError), and then re-ran generation with
    # byte-identical inputs — which could only repeat the first result.
    # All of that dead/broken logic was removed.
    inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt")
    outputs = model.generate(inputs["input_ids"])
    # batch_decode returns one string per generated sequence; unwrap the
    # single entry. The transformers keyword is spelled "tokenization"
    # (the original "clean_up_tokenisation_spaces" was silently wrong /
    # a TypeError depending on the library version).
    res = tokeniser.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return res
# UI components. The gr.inputs / gr.outputs namespaces are deprecated and
# removed in Gradio 4.x; the top-level component classes replace them.
input_matter = gr.Textbox(label="Matière")
input_sommaire = gr.Textbox(label="Sommaire")
output_text = gr.Textbox(label="Next Prediction")


def _predict(matiere, sommaire):
    """Join the two form fields into the single prompt the model expects.

    The original code formatted the component *objects* into an f-string at
    import time (yielding "<gradio...Textbox object ...>") and passed that
    string as `inputs=` to gr.Interface, which is invalid — Interface needs
    the components themselves, and the fn receives their runtime values.
    """
    return generate_titre(f"{matiere} {sommaire}")


if __name__ == "__main__":
    gr.Interface(
        fn=_predict,
        inputs=[input_matter, input_sommaire],
        outputs=output_text,
    ).launch()