ccass-titrage / app.py
maurya's picture
Add application file
61d493e
raw
history blame
1.42 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")
def generate_titre(matiere_and_titrage_prefix):
inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt")
outputs = model.generate(inputs["input_ids"])
res = tokeniser.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True
)
if not (
confirm := gr.inputs.Confirm(
label=f"Le modèle prédit que le titre est {res}. Est-ce le cas ?",
default=False,
).value
):
return {"FAIL"}
pred_titre += f"<t> {res} <t>"
inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt")
outputs = model.generate(inputs["input_ids"])
res = tokeniser.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True
)
return {"next_prediction": res}
input_matter = gr.inputs.Textbox(label="Matière")
input_sommaire = gr.inputs.Textbox(label="Sommaire")
output_text = gr.outputs.Textbox(label="Next Prediction")
pred_titre = ""
matiere_and_titrage_prefix = f"{input_matter} {pred_titre} {input_sommaire}"
gr.Interface(
fn=generate_titre, inputs=matiere_and_titrage_prefix, outputs=output_text
).launch()