File size: 9,405 Bytes
020655e
16558a6
020655e
 
c9893e0
020655e
 
9549ed6
 
020655e
 
 
 
 
 
 
 
d619933
6d43ab6
e5b04ea
 
 
9549ed6
2513305
 
 
9549ed6
2513305
 
 
 
e5b04ea
 
 
dc8f7ad
e5b04ea
8f97a6d
 
9549ed6
2513305
 
 
82b8822
2513305
 
 
020655e
e5b04ea
8f97a6d
 
16558a6
 
3bfad63
16558a6
 
f1db64d
16558a6
 
 
 
 
 
 
 
 
 
 
 
 
 
f1db64d
 
 
16558a6
f1db64d
020655e
f1db64d
16558a6
f1db64d
16558a6
 
 
c894b6e
020655e
 
 
c2845f7
020655e
225a6ee
4b6d551
16558a6
c251a4b
 
16558a6
 
ef1b90d
16558a6
 
6169c4e
3d89a99
fc696d7
16558a6
 
 
225a6ee
15cbf39
225a6ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020655e
16558a6
 
 
 
c251a4b
 
16558a6
 
ef1b90d
16558a6
 
 
6169c4e
fc696d7
16558a6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
from annotated_text import annotated_text
from transformers import pipeline
from PIL import Image
import re

st.sidebar.header("**Instructions**")
st.sidebar.markdown("Démonstrateur du modèle [NERmemBERT](https://hf.co/collections/CATIE-AQ/french-ner-pack-658aefafe3f7a2dcf0e4dbb4) entraîné sur 385 000 à 420 000 données en français en fonction de la configuration. Il est capable d'étiquetter les entités LOC (Localisations), PER (Personnalités), ORG (Organisations) et MISC (Divers). Il est disponible en huit versions : NERmemBERT1-3entities-base (110M de paramètres, contexte de 512 tokens), NERmemBERT2-3entities (111M, 1024 tokens), NERmemBERTa-3entities (111M, 1024 tokens), NERmemBERT1-3entities-large (336M, 512 tokens), NERmemBERT1-4entities-base (110M, 512 tokens), NERmemBERT2-4entities (111M, 1024 tokens), NERmemBERTa-4entities (111M, 1024 tokens), NERmemBERT1-4entities-large (336M, 512 tokens). Pour utiliser l'application, sélectionnez la version de votre choix ci-dessous, puis renseignez un texte. Enfin appuyez sur le bouton « Appliquer le modèle » pour observer la réponse trouvée par le modèle. Pour en savoir plus sur ces modèles, vous pouvez lire l'[article de blog](https://blog.vaniila.ai/NER/) détaillant la démarche suvie.")
version = st.sidebar.radio("Choix de la version du modèle :", ["NERmemBERT1-3entities-base", "NERmemBERT2-3entities", "NERmemBERTa-3entities", "NERmemBERT1-3entities-large","NERmemBERT1-4entities-base", "NERmemBERT2-4entities", "NERmemBERTa-4entities", "NERmemBERT1-4entities-large"])
st.sidebar.markdown("---")
st.sidebar.markdown("Ce modèle a été entraîné via la plateforme [*Vaniila*](https://www.vaniila.ai/) du [CATIE](https://www.catie.fr/).")

image_path = 'Vaniila.png'
image = Image.open(image_path)
st.sidebar.image(image, caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")

@st.cache_resource
def load_model(version,text):
    if version == "NERmemBERT1-3entities-base":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-base-3entities', tokenizer='CATIE-AQ/NERmembert-base-3entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERT2-3entities":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert2-3entities', tokenizer='CATIE-AQ/NERmembert2-3entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERTa-3entities":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmemberta-3entities', tokenizer='CATIE-AQ/NERmemberta-3entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERT1-3entities-large":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-large-3entities', tokenizer='CATIE-AQ/NERmembert-large-3entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERT1-4entities-base":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-base-4entities', tokenizer='CATIE-AQ/NERmembert-base-4entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERT2-4entities":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert2-4entities', tokenizer='CATIE-AQ/NERmembert2-4entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    if version == "NERmemBERTa-4entities":
        ner = pipeline('token-classification', model='CATIE-AQ/NERmemberta-4entities', tokenizer='CATIE-AQ/NERmemberta-4entities', aggregation_strategy="simple")
        result = ner(text)
        return result   
    else:
        ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-large-4entities', tokenizer='CATIE-AQ/NERmembert-large-4entities', aggregation_strategy="simple")
        result = ner(text)
        return result   

def getcolor(texts, labels):
    colors = {'LOC': '#38419D', 'PER': '#BF3131', 'ORG': '#597E52', 'MISC':'#F1C232'}
    return [(t,l,colors[l]) for t, l in zip(texts, labels)]

def color_annotation(to_print,text) :  
    text_ner = []
    label_ner = []
    for i in range(len(to_print)) :
        text_ner.append(to_print[i]["word"])
        label_ner.append(to_print[i]["entity_group"])

    anns = getcolor(text_ner, label_ner)
    anns = list(set(anns))
    text_ner = list(set(text_ner))
    text_ner = list(sorted(text_ner, key = len))

    for i in range(len(anns)):
        for j in range(len(text_ner)):
            if text_ner[j] == anns[i][0]:
                 text = text.replace(text_ner[j],str(anns[i]))

    for i in re.findall(r"\((.*?)\)", text) : # pour gérer les cas de mots inclus dans des n_grams
        if "(" in i:
            text = text.replace(i+")",i.split(', ')[0][2:-1])
            
    text = text.replace(")",')","').replace(')","","',')","').replace("(",'","(').replace('","","(','","(').replace("'-","-")

    return text



st.markdown("<h2 style='text-align: center'>NERmembert", unsafe_allow_html=True)    
st.markdown("<h4 style='text-align: center'>"+version, unsafe_allow_html=True)    
option = st.selectbox(
    'Choix du mode',
    ('Texte libre', 'Exemple 1', 'Exemple 2'))

if option == "Exemple 1":
    text = st.text_area("Votre texte", value="Le dévoilement du logo officiel des JO s'est déroulé le 21 octobre 2019 au Grand Rex. Ce nouvel emblème et cette nouvelle typographie ont été conçus par le designer Sylvain Boyer avec les agences Royalties & Ecobranding. Rond, il rassemble trois symboles : une médaille d'or, la flamme olympique et Marianne, symbolisée par un visage de femme mais privée de son bonnet phrygien caractéristique. La typographie dessinée fait référence à l'Art déco, mouvement artistique des années 1920, décennie pendant laquelle ont eu lieu pour la dernière fois les Jeux olympiques à Paris en 1924. Pour la première fois, ce logo sera unique pour les Jeux olympiques et les Jeux paralympiques.",height=175)
    if text:
        to_print = load_model(version,text)
        display = color_annotation(to_print,text)
        list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
        for i in range(len(display.split('","'))):
            if "#" in display.split('","')[i]:
                list_to_display.append(eval(display.split('","')[i]))
            else :
                list_to_display.append(display.split('","')[i])     
        annotated_text(*list_to_display)
        st.write("\n")
        with st.expander("Afficher le score pour chacune des entitées trouvées :"):
            for i in range(len(to_print)) :
                st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))

elif option == "Exemple 2":
    text = st.text_area("Votre texte", value="Assurés de disputer l'Euro 2024 en Allemagne l'été prochain (du 14 juin au 14 juillet) depuis leur victoire aux Pays-Bas, les Bleus ont fait le nécessaire pour avoir des certitudes. Avec six victoires en six matchs officiels et un seul but encaissé, Didier Deschamps a consolidé les acquis de la dernière Coupe du monde de football. Les joueurs clés sont connus : Kylian Mbappé, Aurélien Tchouameni, Antoine Griezmann, Ibrahima Konaté ou encore Mike Maignan.",height=175)
    if text:
        to_print = load_model(version,text)
        display = color_annotation(to_print,text)
        list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
        for i in range(len(display.split('","'))):
            if "#" in display.split('","')[i]:
                list_to_display.append(eval(display.split('","')[i]))
            else :
                list_to_display.append(display.split('","')[i])     
        annotated_text(*list_to_display)
        st.write("\n")
        with st.expander("Afficher le score pour chacune des entitées trouvées :"):
            for i in range(len(to_print)) :
                st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))
                
else:
    text = st.text_area("Votre texte", value="",height=175)
    if text:
        col1, col2, col3 = st.columns(3)
        if col2.button('Appliquer le modèle'):
            to_print = load_model(version,text)
            display = color_annotation(to_print,text)
            list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
            for i in range(len(display.split('","'))):
                if "#" in display.split('","')[i]:
                    list_to_display.append(eval(display.split('","')[i]))
                else :
                    list_to_display.append(display.split('","')[i])
            annotated_text(*list_to_display)
            st.write("\n")
            with st.expander("Afficher le score pour chacune des entitées trouvées :"):
                for i in range(len(to_print)) :
                    st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))