import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-semi-auto-titrages-base") model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-semi-auto-titrages-base") def generate_titre(matiere_and_titrage_prefix): inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt") outputs = model.generate(inputs["input_ids"]) res = tokeniser.batch_decode( outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True ) if not ( confirm := gr.inputs.Confirm( label=f"Le modèle prédit que le titre est {res}. Est-ce le cas ?", default=False, ).value ): return {"FAIL"} pred_titre += f" {res} " inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt") outputs = model.generate(inputs["input_ids"]) res = tokeniser.batch_decode( outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True ) return {"next_prediction": res} input_matter = gr.inputs.Textbox(label="Matière") input_sommaire = gr.inputs.Textbox(label="Sommaire") output_text = gr.outputs.Textbox(label="Next Prediction") pred_titre = "" matiere_and_titrage_prefix = f"{input_matter} {pred_titre} {input_sommaire}" gr.Interface( fn=generate_titre, inputs=matiere_and_titrage_prefix, outputs=output_text ).launch()