File size: 1,422 Bytes
61d493e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-semi-auto-titrages-base")


def generate_titre(matiere_and_titrage_prefix):
    inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt")
    outputs = model.generate(inputs["input_ids"])
    res = tokeniser.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True
    )
    if not (
        confirm := gr.inputs.Confirm(
            label=f"Le modèle prédit que le titre est {res}. Est-ce le cas ?",
            default=False,
        ).value
    ):
        return {"FAIL"}
    pred_titre += f"<t> {res} <t>"  
    inputs = tokeniser(matiere_and_titrage_prefix, return_tensors="pt")
    outputs = model.generate(inputs["input_ids"])
    res = tokeniser.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenisation_spaces=True
    )
    return {"next_prediction": res}


input_matter = gr.inputs.Textbox(label="Matière")
input_sommaire = gr.inputs.Textbox(label="Sommaire")
output_text = gr.outputs.Textbox(label="Next Prediction")
pred_titre = ""
matiere_and_titrage_prefix = f"{input_matter} {pred_titre} {input_sommaire}"


gr.Interface(
    fn=generate_titre, inputs=matiere_and_titrage_prefix, outputs=output_text
).launch()